Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/errno/errno.go25
-rw-r--r--pkg/abi/linux/file.go4
-rw-r--r--pkg/abi/linux/ioctl.go15
-rw-r--r--pkg/abi/linux/msgqueue.go2
-rw-r--r--pkg/abi/linux/socket.go13
-rw-r--r--pkg/amutex/BUILD2
-rw-r--r--pkg/amutex/amutex.go6
-rw-r--r--pkg/atomicbitops/aligned_32bit_unsafe.go12
-rw-r--r--pkg/atomicbitops/atomicbitops_amd64.s40
-rw-r--r--pkg/atomicbitops/atomicbitops_arm64.s16
-rw-r--r--pkg/atomicbitops/atomicbitops_noasm.go8
-rw-r--r--pkg/buffer/view.go14
-rw-r--r--pkg/buffer/view_test.go18
-rw-r--r--pkg/cpuid/cpuid.go8
-rw-r--r--pkg/cpuid/cpuid_arm64.go14
-rw-r--r--pkg/cpuid/cpuid_x86.go15
-rw-r--r--pkg/crypto/crypto_stdlib.go15
-rw-r--r--pkg/errors/linuxerr/BUILD6
-rw-r--r--pkg/errors/linuxerr/internal.go120
-rw-r--r--pkg/errors/linuxerr/linuxerr.go8
-rw-r--r--pkg/errors/linuxerr/linuxerr_test.go90
-rw-r--r--pkg/eventchannel/BUILD1
-rw-r--r--pkg/eventchannel/event_any.go5
-rw-r--r--pkg/eventchannel/processor.go130
-rw-r--r--pkg/eventfd/BUILD22
-rw-r--r--pkg/eventfd/eventfd.go115
-rw-r--r--pkg/eventfd/eventfd_test.go75
-rw-r--r--pkg/flipcall/ctrl_futex.go5
-rw-r--r--pkg/flipcall/flipcall.go19
-rw-r--r--pkg/lisafs/BUILD117
-rw-r--r--pkg/lisafs/README.md3
-rw-r--r--pkg/lisafs/channel.go190
-rw-r--r--pkg/lisafs/client.go432
-rw-r--r--pkg/lisafs/client_file.go528
-rw-r--r--pkg/lisafs/communicator.go80
-rw-r--r--pkg/lisafs/connection.go320
-rw-r--r--pkg/lisafs/connection_test.go194
-rw-r--r--pkg/lisafs/fd.go374
-rw-r--r--pkg/lisafs/handlers.go768
-rw-r--r--pkg/lisafs/lisafs.go18
-rw-r--r--pkg/lisafs/message.go1251
-rw-r--r--pkg/lisafs/sample_message.go110
-rw-r--r--pkg/lisafs/server.go113
-rw-r--r--pkg/lisafs/sock.go208
-rw-r--r--pkg/lisafs/sock_test.go217
-rw-r--r--pkg/lisafs/testsuite/BUILD20
-rw-r--r--pkg/lisafs/testsuite/testsuite.go637
-rw-r--r--pkg/merkletree/merkletree.go91
-rw-r--r--pkg/p9/client.go4
-rw-r--r--pkg/p9/file.go16
-rw-r--r--pkg/p9/handlers.go2
-rw-r--r--pkg/p9/transport_flipcall.go10
-rw-r--r--pkg/ring0/defs.go12
-rw-r--r--pkg/ring0/defs_amd64.go5
-rw-r--r--pkg/ring0/entry_amd64.go5
-rw-r--r--pkg/ring0/entry_amd64.s177
-rw-r--r--pkg/ring0/kernel.go5
-rw-r--r--pkg/ring0/kernel_amd64.go23
-rw-r--r--pkg/ring0/lib_amd64.go42
-rw-r--r--pkg/ring0/lib_amd64.s23
-rw-r--r--pkg/ring0/offsets_amd64.go3
-rw-r--r--pkg/ring0/pagetables/pagetables.go9
-rw-r--r--pkg/safecopy/BUILD5
-rw-r--r--pkg/safecopy/safecopy.go16
-rw-r--r--pkg/safecopy/safecopy_unsafe.go37
-rw-r--r--pkg/safemem/io.go52
-rw-r--r--pkg/sentry/arch/fpu/BUILD1
-rw-r--r--pkg/sentry/arch/fpu/fpu.go13
-rw-r--r--pkg/sentry/arch/fpu/fpu_unsafe.go31
-rw-r--r--pkg/sentry/control/BUILD12
-rw-r--r--pkg/sentry/control/control.proto40
-rw-r--r--pkg/sentry/control/events.go65
-rw-r--r--pkg/sentry/control/pprof.go30
-rw-r--r--pkg/sentry/control/usage.go183
-rw-r--r--pkg/sentry/devices/memdev/BUILD2
-rw-r--r--pkg/sentry/devices/memdev/full.go6
-rw-r--r--pkg/sentry/devices/quotedev/BUILD16
-rw-r--r--pkg/sentry/devices/quotedev/quotedev.go52
-rw-r--r--pkg/sentry/devices/ttydev/BUILD2
-rw-r--r--pkg/sentry/devices/ttydev/ttydev.go4
-rw-r--r--pkg/sentry/fs/BUILD1
-rw-r--r--pkg/sentry/fs/copy_up.go21
-rw-r--r--pkg/sentry/fs/dev/BUILD1
-rw-r--r--pkg/sentry/fs/dev/full.go4
-rw-r--r--pkg/sentry/fs/dirent.go5
-rw-r--r--pkg/sentry/fs/fdpipe/BUILD2
-rw-r--r--pkg/sentry/fs/fdpipe/pipe.go7
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_opener.go20
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_opener_test.go15
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_test.go11
-rw-r--r--pkg/sentry/fs/file.go50
-rw-r--r--pkg/sentry/fs/file_operations.go2
-rw-r--r--pkg/sentry/fs/file_overlay.go15
-rw-r--r--pkg/sentry/fs/fsutil/BUILD1
-rw-r--r--pkg/sentry/fs/fsutil/file.go13
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go21
-rw-r--r--pkg/sentry/fs/fsutil/inode.go11
-rw-r--r--pkg/sentry/fs/gofer/BUILD4
-rw-r--r--pkg/sentry/fs/gofer/file.go6
-rw-r--r--pkg/sentry/fs/gofer/gofer_test.go8
-rw-r--r--pkg/sentry/fs/gofer/inode.go8
-rw-r--r--pkg/sentry/fs/gofer/path.go7
-rw-r--r--pkg/sentry/fs/host/BUILD1
-rw-r--r--pkg/sentry/fs/host/file.go5
-rw-r--r--pkg/sentry/fs/host/inode.go5
-rw-r--r--pkg/sentry/fs/host/tty.go7
-rw-r--r--pkg/sentry/fs/host/util.go2
-rw-r--r--pkg/sentry/fs/inode.go3
-rw-r--r--pkg/sentry/fs/inode_operations.go2
-rw-r--r--pkg/sentry/fs/inode_overlay.go5
-rw-r--r--pkg/sentry/fs/inotify.go7
-rw-r--r--pkg/sentry/fs/proc/BUILD1
-rw-r--r--pkg/sentry/fs/proc/exec_args.go2
-rw-r--r--pkg/sentry/fs/proc/fds.go6
-rw-r--r--pkg/sentry/fs/proc/proc.go7
-rw-r--r--pkg/sentry/fs/proc/sys.go3
-rw-r--r--pkg/sentry/fs/proc/task.go127
-rw-r--r--pkg/sentry/fs/ramfs/BUILD1
-rw-r--r--pkg/sentry/fs/ramfs/dir.go5
-rw-r--r--pkg/sentry/fs/splice.go17
-rw-r--r--pkg/sentry/fs/timerfd/BUILD1
-rw-r--r--pkg/sentry/fs/timerfd/timerfd.go3
-rw-r--r--pkg/sentry/fs/tty/BUILD1
-rw-r--r--pkg/sentry/fs/tty/dir.go11
-rw-r--r--pkg/sentry/fs/tty/line_discipline.go10
-rw-r--r--pkg/sentry/fs/tty/queue.go6
-rw-r--r--pkg/sentry/fs/user/BUILD1
-rw-r--r--pkg/sentry/fs/user/path.go7
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/BUILD2
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/base.go9
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cgroupfs.go181
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/memory.go30
-rw-r--r--pkg/sentry/fsimpl/devpts/BUILD1
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts.go7
-rw-r--r--pkg/sentry/fsimpl/devpts/line_discipline.go10
-rw-r--r--pkg/sentry/fsimpl/devpts/queue.go6
-rw-r--r--pkg/sentry/fsimpl/eventfd/BUILD2
-rw-r--r--pkg/sentry/fsimpl/eventfd/eventfd.go10
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD0
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD2
-rw-r--r--pkg/sentry/fsimpl/fuse/dev.go15
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go4
-rw-r--r--pkg/sentry/fsimpl/fuse/directory.go12
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go3
-rw-r--r--pkg/sentry/fsimpl/fuse/read_write.go3
-rw-r--r--pkg/sentry/fsimpl/fuse/regular_file.go7
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD4
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go101
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go519
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go722
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer_test.go1
-rw-r--r--pkg/sentry/fsimpl/gofer/handle.go121
-rw-r--r--pkg/sentry/fsimpl/gofer/host_named_pipe.go3
-rw-r--r--pkg/sentry/fsimpl/gofer/p9file.go12
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go26
-rw-r--r--pkg/sentry/fsimpl/gofer/revalidate.go50
-rw-r--r--pkg/sentry/fsimpl/gofer/save_restore.go147
-rw-r--r--pkg/sentry/fsimpl/gofer/socket.go45
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go154
-rw-r--r--pkg/sentry/fsimpl/gofer/symlink.go8
-rw-r--r--pkg/sentry/fsimpl/gofer/time.go5
-rw-r--r--pkg/sentry/fsimpl/host/BUILD1
-rw-r--r--pkg/sentry/fsimpl/host/host.go5
-rw-r--r--pkg/sentry/fsimpl/host/tty.go7
-rw-r--r--pkg/sentry/fsimpl/host/util.go2
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD3
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go3
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go19
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go38
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go67
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go61
-rw-r--r--pkg/sentry/fsimpl/overlay/BUILD1
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go12
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go53
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD1
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go17
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go16
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go64
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go8
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go5
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go9
-rw-r--r--pkg/sentry/fsimpl/signalfd/BUILD2
-rw-r--r--pkg/sentry/fsimpl/signalfd/signalfd.go4
-rw-r--r--pkg/sentry/fsimpl/sys/BUILD1
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go14
-rw-r--r--pkg/sentry/fsimpl/sys/sys_test.go14
-rw-r--r--pkg/sentry/fsimpl/timerfd/BUILD1
-rw-r--r--pkg/sentry/fsimpl/timerfd/timerfd.go3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go31
-rw-r--r--pkg/sentry/fsimpl/tmpfs/pipe_test.go3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file_test.go6
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go3
-rw-r--r--pkg/sentry/fsimpl/verity/BUILD15
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go69
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go284
-rw-r--r--pkg/sentry/hostmm/BUILD3
-rw-r--r--pkg/sentry/hostmm/hostmm.go42
-rw-r--r--pkg/sentry/inet/BUILD13
-rw-r--r--pkg/sentry/kernel/BUILD2
-rw-r--r--pkg/sentry/kernel/auth/BUILD1
-rw-r--r--pkg/sentry/kernel/cgroup.go1
-rw-r--r--pkg/sentry/kernel/eventfd/BUILD2
-rw-r--r--pkg/sentry/kernel/eventfd/eventfd.go10
-rw-r--r--pkg/sentry/kernel/futex/BUILD1
-rw-r--r--pkg/sentry/kernel/futex/futex.go5
-rw-r--r--pkg/sentry/kernel/ipc/object.go35
-rw-r--r--pkg/sentry/kernel/kernel.go14
-rw-r--r--pkg/sentry/kernel/msgqueue/msgqueue.go121
-rw-r--r--pkg/sentry/kernel/pipe/BUILD2
-rw-r--r--pkg/sentry/kernel/pipe/node.go5
-rw-r--r--pkg/sentry/kernel/pipe/node_test.go3
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go9
-rw-r--r--pkg/sentry/kernel/pipe/pipe_test.go12
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go5
-rw-r--r--pkg/sentry/kernel/ptrace.go17
-rw-r--r--pkg/sentry/kernel/ptrace_amd64.go4
-rw-r--r--pkg/sentry/kernel/ptrace_arm64.go4
-rw-r--r--pkg/sentry/kernel/seccomp.go4
-rw-r--r--pkg/sentry/kernel/semaphore/BUILD3
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go23
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore_test.go10
-rw-r--r--pkg/sentry/kernel/shm/BUILD1
-rw-r--r--pkg/sentry/kernel/shm/shm.go26
-rw-r--r--pkg/sentry/kernel/signalfd/BUILD1
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go3
-rw-r--r--pkg/sentry/kernel/task.go25
-rw-r--r--pkg/sentry/kernel/task_block.go11
-rw-r--r--pkg/sentry/kernel/task_clone.go38
-rw-r--r--pkg/sentry/kernel/task_exec.go12
-rw-r--r--pkg/sentry/kernel/task_exit.go13
-rw-r--r--pkg/sentry/kernel/task_image.go2
-rw-r--r--pkg/sentry/kernel/task_log.go12
-rw-r--r--pkg/sentry/kernel/task_net.go12
-rw-r--r--pkg/sentry/kernel/task_run.go6
-rw-r--r--pkg/sentry/kernel/task_signals.go11
-rw-r--r--pkg/sentry/kernel/task_start.go7
-rw-r--r--pkg/sentry/kernel/task_syscall.go7
-rw-r--r--pkg/sentry/kernel/task_usermem.go7
-rw-r--r--pkg/sentry/kernel/thread_group.go15
-rw-r--r--pkg/sentry/loader/BUILD1
-rw-r--r--pkg/sentry/loader/elf.go57
-rw-r--r--pkg/sentry/loader/interpreter.go8
-rw-r--r--pkg/sentry/loader/loader.go19
-rw-r--r--pkg/sentry/loader/vdso.go71
-rw-r--r--pkg/sentry/memmap/BUILD1
-rw-r--r--pkg/sentry/mm/BUILD1
-rw-r--r--pkg/sentry/mm/aio_context.go18
-rw-r--r--pkg/sentry/mm/mm.go1
-rw-r--r--pkg/sentry/mm/pma.go41
-rw-r--r--pkg/sentry/mm/syscalls.go47
-rw-r--r--pkg/sentry/mm/vma.go9
-rw-r--r--pkg/sentry/pgalloc/BUILD1
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go3
-rw-r--r--pkg/sentry/platform/kvm/BUILD24
-rw-r--r--pkg/sentry/platform/kvm/bluepill.go7
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.go8
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.s27
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go10
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.s34
-rw-r--r--pkg/sentry/platform/kvm/bluepill_fault.go6
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go32
-rw-r--r--pkg/sentry/platform/kvm/kvm.go6
-rw-r--r--pkg/sentry/platform/kvm/kvm_safecopy_test.go104
-rw-r--r--pkg/sentry/platform/kvm/machine.go148
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go36
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64_unsafe.go12
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go125
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go15
-rw-r--r--pkg/sentry/platform/kvm/machine_unsafe.go43
-rw-r--r--pkg/sentry/platform/kvm/physical_map.go3
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_arm64.go4
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_arm64.s35
-rw-r--r--pkg/sentry/seccheck/BUILD58
-rw-r--r--pkg/sentry/seccheck/clone.go53
-rw-r--r--pkg/sentry/seccheck/execve.go65
-rw-r--r--pkg/sentry/seccheck/exit.go57
-rw-r--r--pkg/sentry/seccheck/seccheck.go158
-rw-r--r--pkg/sentry/seccheck/seccheck_test.go157
-rw-r--r--pkg/sentry/seccheck/task.go39
-rw-r--r--pkg/sentry/socket/BUILD5
-rw-r--r--pkg/sentry/socket/control/control.go32
-rw-r--r--pkg/sentry/socket/control/control_test.go2
-rw-r--r--pkg/sentry/socket/hostinet/BUILD1
-rw-r--r--pkg/sentry/socket/hostinet/socket.go39
-rw-r--r--pkg/sentry/socket/netfilter/targets.go2
-rw-r--r--pkg/sentry/socket/netlink/BUILD1
-rw-r--r--pkg/sentry/socket/netlink/socket.go5
-rw-r--r--pkg/sentry/socket/netstack/BUILD3
-rw-r--r--pkg/sentry/socket/netstack/netstack.go200
-rw-r--r--pkg/sentry/socket/netstack/netstack_state.go31
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go7
-rw-r--r--pkg/sentry/socket/netstack/stack.go2
-rw-r--r--pkg/sentry/socket/socket.go40
-rw-r--r--pkg/sentry/socket/socket_state.go27
-rw-r--r--pkg/sentry/socket/unix/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go9
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go5
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go10
-rw-r--r--pkg/sentry/socket/unix/unix.go11
-rw-r--r--pkg/sentry/strace/strace.go4
-rw-r--r--pkg/sentry/syscalls/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/error.go10
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go183
-rw-r--r--pkg/sentry/syscalls/linux/sigset.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_epoll.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go38
-rw-r--r--pkg/sentry/syscalls/linux/sys_futex.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_getdents.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_lseek.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_mmap.go12
-rw-r--r--pkg/sentry/syscalls/linux/sys_msgqueue.go53
-rw-r--r--pkg/sentry/syscalls/linux/sys_poll.go14
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go25
-rw-r--r--pkg/sentry/syscalls/linux/sys_rlimit.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_rseq.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go19
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go45
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go11
-rw-r--r--pkg/sentry/syscalls/linux/sys_sync.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_syslog.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_time.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_timer.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_tls_amd64.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_tls_arm64.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go25
-rw-r--r--pkg/sentry/syscalls/linux/timespec.go9
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/aio.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/execve.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/filesystem.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/path.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/poll.go17
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/read_write.go45
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go48
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/splice.go25
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sync.go6
-rw-r--r--pkg/sentry/syscalls/syscalls.go5
-rw-r--r--pkg/sentry/time/BUILD1
-rw-r--r--pkg/sentry/time/sampler_arm64.go4
-rw-r--r--pkg/sentry/vfs/BUILD2
-rw-r--r--pkg/sentry/vfs/README.md4
-rw-r--r--pkg/sentry/vfs/epoll.go5
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go19
-rw-r--r--pkg/sentry/vfs/file_description_impl_util_test.go5
-rw-r--r--pkg/sentry/vfs/inotify.go3
-rw-r--r--pkg/sentry/vfs/lock.go10
-rw-r--r--pkg/sentry/vfs/mount.go3
-rw-r--r--pkg/sentry/vfs/pathname.go4
-rw-r--r--pkg/sentry/vfs/permissions.go5
-rw-r--r--pkg/sentry/vfs/resolving_path.go9
-rw-r--r--pkg/sentry/vfs/vfs.go19
-rw-r--r--pkg/shim/runtimeoptions/runtimeoptions.go7
-rw-r--r--pkg/sighandling/BUILD (renamed from pkg/sentry/sighandling/BUILD)2
-rw-r--r--pkg/sighandling/sighandling.go (renamed from pkg/sentry/sighandling/sighandling.go)0
-rw-r--r--pkg/sighandling/sighandling_unsafe.go (renamed from pkg/sentry/sighandling/sighandling_unsafe.go)34
-rw-r--r--pkg/sync/atomicptr/generic_atomicptr_unsafe.go2
-rw-r--r--pkg/sync/atomicptrmap/BUILD1
-rw-r--r--pkg/syserr/BUILD2
-rw-r--r--pkg/syserr/syserr.go31
-rw-r--r--pkg/syserror/syserror.go180
-rw-r--r--pkg/tcpip/BUILD2
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go112
-rw-r--r--pkg/tcpip/checker/checker.go15
-rw-r--r--pkg/tcpip/header/eth.go6
-rw-r--r--pkg/tcpip/header/eth_test.go4
-rw-r--r--pkg/tcpip/internal/tcp/BUILD12
-rw-r--r--pkg/tcpip/internal/tcp/tcp.go48
-rw-r--r--pkg/tcpip/link/channel/channel.go24
-rw-r--r--pkg/tcpip/link/ethernet/ethernet.go28
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go3
-rw-r--r--pkg/tcpip/link/loopback/loopback.go31
-rw-r--r--pkg/tcpip/link/muxed/injectable.go5
-rw-r--r--pkg/tcpip/link/nested/nested.go15
-rw-r--r--pkg/tcpip/link/packetsocket/BUILD14
-rw-r--r--pkg/tcpip/link/packetsocket/endpoint.go50
-rw-r--r--pkg/tcpip/link/pipe/pipe.go3
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go10
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_amd64.s2
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_arm64.s2
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go32
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD30
-rw-r--r--pkg/tcpip/link/sharedmem/queue/rx.go2
-rw-r--r--pkg/tcpip/link/sharedmem/queuepair.go199
-rw-r--r--pkg/tcpip/link/sharedmem/rx.go30
-rw-r--r--pkg/tcpip/link/sharedmem/server_rx.go142
-rw-r--r--pkg/tcpip/link/sharedmem/server_tx.go175
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go233
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_server.go333
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_server_test.go220
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go103
-rw-r--r--pkg/tcpip/link/sharedmem/tx.go20
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go5
-rw-r--r--pkg/tcpip/link/tun/BUILD1
-rw-r--r--pkg/tcpip/link/tun/device.go5
-rw-r--r--pkg/tcpip/link/waitable/waitable.go12
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go5
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/arp.go1
-rw-r--r--pkg/tcpip/network/arp/arp_test.go16
-rw-r--r--pkg/tcpip/network/internal/testutil/testutil.go5
-rw-r--r--pkg/tcpip/network/ip_test.go175
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go105
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go28
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go85
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go218
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go91
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go71
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go100
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go243
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go33
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go6
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go48
-rw-r--r--pkg/tcpip/network/multicast_group_test.go21
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go8
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go16
-rw-r--r--pkg/tcpip/socketops.go29
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go24
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go4
-rw-r--r--pkg/tcpip/stack/conntrack.go557
-rw-r--r--pkg/tcpip/stack/forwarding_test.go30
-rw-r--r--pkg/tcpip/stack/icmp_rate_limit.go39
-rw-r--r--pkg/tcpip/stack/iptables.go198
-rw-r--r--pkg/tcpip/stack/iptables_state.go4
-rw-r--r--pkg/tcpip/stack/iptables_targets.go150
-rw-r--r--pkg/tcpip/stack/iptables_types.go28
-rw-r--r--pkg/tcpip/stack/ndp_test.go107
-rw-r--r--pkg/tcpip/stack/nic.go148
-rw-r--r--pkg/tcpip/stack/nic_test.go5
-rw-r--r--pkg/tcpip/stack/packet_buffer.go70
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go12
-rw-r--r--pkg/tcpip/stack/registration.go47
-rw-r--r--pkg/tcpip/stack/stack.go122
-rw-r--r--pkg/tcpip/stack/stack_test.go624
-rw-r--r--pkg/tcpip/stack/tcp.go15
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go50
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go18
-rw-r--r--pkg/tcpip/stack/transport_test.go33
-rw-r--r--pkg/tcpip/tcpip.go54
-rw-r--r--pkg/tcpip/tcpip_state.go27
-rw-r--r--pkg/tcpip/tests/integration/BUILD4
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go20
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go386
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go24
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go47
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go20
-rw-r--r--pkg/tcpip/tests/integration/route_test.go38
-rw-r--r--pkg/tcpip/tests/utils/utils.go68
-rw-r--r--pkg/tcpip/transport/BUILD (renamed from pkg/syserror/BUILD)9
-rw-r--r--pkg/tcpip/transport/datagram.go49
-rw-r--r--pkg/tcpip/transport/icmp/BUILD2
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go437
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go35
-rw-r--r--pkg/tcpip/transport/icmp/icmp_test.go8
-rw-r--r--pkg/tcpip/transport/internal/network/BUILD46
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go811
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_state.go58
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_test.go318
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go252
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go24
-rw-r--r--pkg/tcpip/transport/raw/BUILD2
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go436
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go30
-rw-r--r--pkg/tcpip/transport/tcp/BUILD15
-rw-r--r--pkg/tcpip/transport/tcp/accept.go399
-rw-r--r--pkg/tcpip/transport/tcp/connect.go133
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go218
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go9
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go51
-rw-r--r--pkg/tcpip/transport/tcp/rack.go3
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go2
-rw-r--r--pkg/tcpip/transport/tcp/rcv_test.go2
-rw-r--r--pkg/tcpip/transport/tcp/segment_test.go2
-rw-r--r--pkg/tcpip/transport/tcp/snd.go141
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go24
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go265
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go684
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go89
-rw-r--r--pkg/tcpip/transport/transport.go16
-rw-r--r--pkg/tcpip/transport/udp/BUILD3
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go903
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go65
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go21
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go161
-rw-r--r--pkg/unet/BUILD1
-rw-r--r--pkg/unet/unet.go20
-rw-r--r--pkg/unet/unet_unsafe.go2
-rw-r--r--pkg/usermem/usermem.go2
499 files changed, 20273 insertions, 6240 deletions
diff --git a/pkg/abi/linux/errno/errno.go b/pkg/abi/linux/errno/errno.go
index 5a09c6605..38ebbb1d7 100644
--- a/pkg/abi/linux/errno/errno.go
+++ b/pkg/abi/linux/errno/errno.go
@@ -157,9 +157,32 @@ const (
EHWPOISON
)
-// errnos derived from other errnos
+// errnos derived from other errnos.
const (
EWOULDBLOCK = EAGAIN
EDEADLOCK = EDEADLK
ENONET = ENOENT
)
+
+// errnos for internal errors.
+const (
+ // ERESTARTSYS is returned by an interrupted syscall to indicate that it
+ // should be converted to EINTR if interrupted by a signal delivered to a
+ // user handler without SA_RESTART set, and restarted otherwise.
+ ERESTARTSYS = 512
+
+ // ERESTARTNOINTR is returned by an interrupted syscall to indicate that it
+ // should always be restarted.
+ ERESTARTNOINTR = 513
+
+ // ERESTARTNOHAND is returned by an interrupted syscall to indicate that it
+ // should be converted to EINTR if interrupted by a signal delivered to a
+ // user handler, and restarted otherwise.
+ ERESTARTNOHAND = 514
+
+ // ERESTART_RESTARTBLOCK is returned by an interrupted syscall to indicate
+ // that it should be restarted using a custom function. The interrupted
+ // syscall must register a custom restart function by calling
+ // Task.SetRestartSyscallFn.
+ ERESTART_RESTARTBLOCK = 516
+)
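
For orientation, a hedged sketch (not code from this change) of how the four restart errnos above differ. EINTR's value is assumed from errno-base.h, and ERESTART_RESTARTBLOCK's custom restart function is not modeled.

package main

import "fmt"

const EINTR = 4 // user-visible errno, assumed from errno-base.h

const (
	ERESTARTSYS           = 512
	ERESTARTNOINTR        = 513
	ERESTARTNOHAND        = 514
	ERESTART_RESTARTBLOCK = 516
)

// resolve returns whether the interrupted syscall is restarted and, if
// not, the errno reported to userspace.
func resolve(e int, handlerInvoked, saRestart bool) (restart bool, errno int) {
	switch e {
	case ERESTARTSYS: // EINTR only for a handler without SA_RESTART
		if handlerInvoked && !saRestart {
			return false, EINTR
		}
		return true, 0
	case ERESTARTNOINTR: // always restarted
		return true, 0
	case ERESTARTNOHAND, ERESTART_RESTARTBLOCK: // EINTR if a handler ran
		if handlerInvoked {
			return false, EINTR
		}
		return true, 0 // ERESTART_RESTARTBLOCK restarts via a custom function
	}
	return false, e
}

func main() {
	fmt.Println(resolve(ERESTARTSYS, true, false)) // false 4
}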
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index 1e23850a9..67646f837 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -242,7 +242,7 @@ const (
// Statx represents struct statx.
//
-// +marshal
+// +marshal slice:StatxSlice
type Statx struct {
Mask uint32
Blksize uint32
@@ -270,6 +270,8 @@ type Statx struct {
var SizeOfStatx = (*Statx)(nil).SizeBytes()
// FileMode represents a mode_t.
+//
+// +marshal
type FileMode uint16
// Permissions returns just the permission bits.
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 006b5a525..29062c97a 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -170,3 +170,18 @@ const (
KCOV_MODE_TRACE_PC = 2
KCOV_MODE_TRACE_CMP = 3
)
+
+// Attestation ioctls.
+var (
+ SIGN_ATTESTATION_REPORT = IOC(_IOC_READ, 's', 1, 65)
+)
+
+// SizeOfQuoteInputData is the number of bytes in the input data of the ioctl
+// call to get a quote.
+const SizeOfQuoteInputData = 64
+
+// SignReport is a struct used to obtain a signed quote from the input data.
+type SignReport struct {
+ data [64]byte
+ quote []byte
+}
diff --git a/pkg/abi/linux/msgqueue.go b/pkg/abi/linux/msgqueue.go
index e1e8d0357..0612a8214 100644
--- a/pkg/abi/linux/msgqueue.go
+++ b/pkg/abi/linux/msgqueue.go
@@ -47,7 +47,7 @@ const (
MSGSSZ = 16
// MSGSEG is simplified due to the absence of a ternary operator.
- MSGSEG = (MSGPOOL * 1024) / MSGSSZ
+ MSGSEG = 0xffff
)
// MsqidDS is equivalent to struct msqid64_ds. Source:
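
For context, a sketch of why MSGSEG collapses to a constant: Linux defines it with a ternary clamp, #define MSGSEG (__MSGSEG <= 0xffff ? __MSGSEG : 0xffff), which Go cannot express directly. With Linux's default values (assumed here), the computed value exceeds the clamp, so 0xffff is the result.

package main

import "fmt"

// Values assumed from Linux defaults (include/uapi/linux/msg.h).
const (
	MSGMNI  = 32000
	MSGMNB  = 16384
	MSGPOOL = MSGMNI * MSGMNB / 1024
	MSGSSZ  = 16
)

// msgseg spells out the ternary clamp Linux applies to MSGSEG.
func msgseg() int {
	if s := (MSGPOOL * 1024) / MSGSSZ; s <= 0xffff {
		return s
	}
	return 0xffff
}

func main() {
	fmt.Println(msgseg()) // 65535, i.e. 0xffff
}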
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 95871b8a5..f60e42997 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -542,6 +542,15 @@ type ControlMessageIPPacketInfo struct {
DestinationAddr InetAddr
}
+// ControlMessageIPv6PacketInfo represents struct in6_pktinfo from linux/ipv6.h.
+//
+// +marshal
+// +stateify savable
+type ControlMessageIPv6PacketInfo struct {
+ Addr Inet6Addr
+ NIC uint32
+}
+
// SizeOfControlMessageCredentials is the binary size of a
// ControlMessageCredentials struct.
var SizeOfControlMessageCredentials = (*ControlMessageCredentials)(nil).SizeBytes()
@@ -566,6 +575,10 @@ const SizeOfControlMessageTClass = 4
// control message.
const SizeOfControlMessageIPPacketInfo = 12
+// SizeOfControlMessageIPv6PacketInfo is the size of a
+// ControlMessageIPv6PacketInfo.
+const SizeOfControlMessageIPv6PacketInfo = 20
+
// SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call.
// From net/scm.h.
const SCM_MAX_FD = 253
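
The 20-byte size above follows from struct in6_pktinfo's layout: a 16-byte IPv6 address followed by a 4-byte interface index. A stand-alone sanity check (a sketch, not gVisor's marshalling code):

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// in6Pktinfo mirrors struct in6_pktinfo from linux/ipv6.h.
type in6Pktinfo struct {
	Addr [16]byte // ipi6_addr
	NIC  uint32   // ipi6_ifindex
}

func main() {
	var buf bytes.Buffer
	if err := binary.Write(&buf, binary.LittleEndian, in6Pktinfo{NIC: 2}); err != nil {
		panic(err)
	}
	fmt.Println(buf.Len()) // 20
}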
diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD
index bd3a5cce9..6d8b5f818 100644
--- a/pkg/amutex/BUILD
+++ b/pkg/amutex/BUILD
@@ -8,7 +8,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/context",
- "//pkg/syserror",
+ "//pkg/errors/linuxerr",
],
)
diff --git a/pkg/amutex/amutex.go b/pkg/amutex/amutex.go
index d7acc1d9f..985199cfa 100644
--- a/pkg/amutex/amutex.go
+++ b/pkg/amutex/amutex.go
@@ -20,7 +20,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
)
// Sleeper must be implemented by users of the abortable mutex to allow for
@@ -33,7 +33,7 @@ type NoopSleeper = context.Context
// Block blocks until either receiving from ch succeeds (in which case it
// returns nil) or sleeper is interrupted (in which case it returns
-// syserror.ErrInterrupted).
+// linuxerr.ErrInterrupted).
func Block(sleeper Sleeper, ch <-chan struct{}) error {
cancel := sleeper.SleepStart()
select {
@@ -42,7 +42,7 @@ func Block(sleeper Sleeper, ch <-chan struct{}) error {
return nil
case <-cancel:
sleeper.SleepFinish(false)
- return syserror.ErrInterrupted
+ return linuxerr.ErrInterrupted
}
}
diff --git a/pkg/atomicbitops/aligned_32bit_unsafe.go b/pkg/atomicbitops/aligned_32bit_unsafe.go
index a17d317cc..0e4765c48 100644
--- a/pkg/atomicbitops/aligned_32bit_unsafe.go
+++ b/pkg/atomicbitops/aligned_32bit_unsafe.go
@@ -39,9 +39,9 @@ type AlignedAtomicInt64 struct {
}
func (aa *AlignedAtomicInt64) ptr() *int64 {
- // On 32-bit systems, aa.value is is guaranteed to be 32-bit aligned.
- // It means that in the 12-byte aa.value, there are guaranteed to be 8
- // contiguous bytes with 64-bit alignment.
+ // On 32-bit systems, aa.value is guaranteed to be 32-bit aligned. It means
+ // that in the 12-byte aa.value, there are guaranteed to be 8 contiguous bytes
+ // with 64-bit alignment.
return (*int64)(unsafe.Pointer((uintptr(unsafe.Pointer(&aa.value)) + 4) &^ 7))
}
@@ -77,9 +77,9 @@ type AlignedAtomicUint64 struct {
}
func (aa *AlignedAtomicUint64) ptr() *uint64 {
- // On 32-bit systems, aa.value is is guaranteed to be 32-bit aligned.
- // It means that in the 12-byte aa.value, there are guaranteed to be 8
- // contiguous bytes with 64-bit alignment.
+ // On 32-bit systems, aa.value is guaranteed to be 32-bit aligned. It means
+ // that in the 12-byte aa.value, there are guaranteed to be 8 contiguous bytes
+ // with 64-bit alignment.
return (*uint64)(unsafe.Pointer((uintptr(unsafe.Pointer(&aa.value)) + 4) &^ 7))
}
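
The rounding above relies on a small invariant: for any 4-byte-aligned address addr, (addr+4) &^ 7 is 8-byte aligned, and the 8-byte window starting there lies entirely within [addr, addr+12). A quick stand-alone check of that invariant:

package main

import "fmt"

func main() {
	ok := true
	for addr := uintptr(0); addr < 4096; addr += 4 {
		aligned := (addr + 4) &^ 7
		if aligned%8 != 0 || aligned < addr || aligned+8 > addr+12 {
			ok = false
			fmt.Println("violated at", addr)
		}
	}
	fmt.Println("invariant holds:", ok) // invariant holds: true
}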
diff --git a/pkg/atomicbitops/atomicbitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s
index 54c887ee5..6b9a67adc 100644
--- a/pkg/atomicbitops/atomicbitops_amd64.s
+++ b/pkg/atomicbitops/atomicbitops_amd64.s
@@ -16,28 +16,28 @@
#include "textflag.h"
-TEXT ·AndUint32(SB),$0-12
- MOVQ addr+0(FP), BP
+TEXT ·AndUint32(SB),NOSPLIT,$0-12
+ MOVQ addr+0(FP), BX
MOVL val+8(FP), AX
LOCK
- ANDL AX, 0(BP)
+ ANDL AX, 0(BX)
RET
-TEXT ·OrUint32(SB),$0-12
- MOVQ addr+0(FP), BP
+TEXT ·OrUint32(SB),NOSPLIT,$0-12
+ MOVQ addr+0(FP), BX
MOVL val+8(FP), AX
LOCK
- ORL AX, 0(BP)
+ ORL AX, 0(BX)
RET
-TEXT ·XorUint32(SB),$0-12
- MOVQ addr+0(FP), BP
+TEXT ·XorUint32(SB),NOSPLIT,$0-12
+ MOVQ addr+0(FP), BX
MOVL val+8(FP), AX
LOCK
- XORL AX, 0(BP)
+ XORL AX, 0(BX)
RET
-TEXT ·CompareAndSwapUint32(SB),$0-20
+TEXT ·CompareAndSwapUint32(SB),NOSPLIT,$0-20
MOVQ addr+0(FP), DI
MOVL old+8(FP), AX
MOVL new+12(FP), DX
@@ -46,28 +46,28 @@ TEXT ·CompareAndSwapUint32(SB),$0-20
MOVL AX, ret+16(FP)
RET
-TEXT ·AndUint64(SB),$0-16
- MOVQ addr+0(FP), BP
+TEXT ·AndUint64(SB),NOSPLIT,$0-16
+ MOVQ addr+0(FP), BX
MOVQ val+8(FP), AX
LOCK
- ANDQ AX, 0(BP)
+ ANDQ AX, 0(BX)
RET
-TEXT ·OrUint64(SB),$0-16
- MOVQ addr+0(FP), BP
+TEXT ·OrUint64(SB),NOSPLIT,$0-16
+ MOVQ addr+0(FP), BX
MOVQ val+8(FP), AX
LOCK
- ORQ AX, 0(BP)
+ ORQ AX, 0(BX)
RET
-TEXT ·XorUint64(SB),$0-16
- MOVQ addr+0(FP), BP
+TEXT ·XorUint64(SB),NOSPLIT,$0-16
+ MOVQ addr+0(FP), BX
MOVQ val+8(FP), AX
LOCK
- XORQ AX, 0(BP)
+ XORQ AX, 0(BX)
RET
-TEXT ·CompareAndSwapUint64(SB),$0-32
+TEXT ·CompareAndSwapUint64(SB),NOSPLIT,$0-32
MOVQ addr+0(FP), DI
MOVQ old+8(FP), AX
MOVQ new+16(FP), DX
diff --git a/pkg/atomicbitops/atomicbitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s
index 5c780851b..644a6bca5 100644
--- a/pkg/atomicbitops/atomicbitops_arm64.s
+++ b/pkg/atomicbitops/atomicbitops_arm64.s
@@ -16,7 +16,7 @@
#include "textflag.h"
-TEXT ·AndUint32(SB),$0-12
+TEXT ·AndUint32(SB),NOSPLIT,$0-12
MOVD ptr+0(FP), R0
MOVW val+8(FP), R1
again:
@@ -26,7 +26,7 @@ again:
CBNZ R3, again
RET
-TEXT ·OrUint32(SB),$0-12
+TEXT ·OrUint32(SB),NOSPLIT,$0-12
MOVD ptr+0(FP), R0
MOVW val+8(FP), R1
again:
@@ -36,7 +36,7 @@ again:
CBNZ R3, again
RET
-TEXT ·XorUint32(SB),$0-12
+TEXT ·XorUint32(SB),NOSPLIT,$0-12
MOVD ptr+0(FP), R0
MOVW val+8(FP), R1
again:
@@ -46,7 +46,7 @@ again:
CBNZ R3, again
RET
-TEXT ·CompareAndSwapUint32(SB),$0-20
+TEXT ·CompareAndSwapUint32(SB),NOSPLIT,$0-20
MOVD addr+0(FP), R0
MOVW old+8(FP), R1
MOVW new+12(FP), R2
@@ -60,7 +60,7 @@ done:
MOVW R3, prev+16(FP)
RET
-TEXT ·AndUint64(SB),$0-16
+TEXT ·AndUint64(SB),NOSPLIT,$0-16
MOVD ptr+0(FP), R0
MOVD val+8(FP), R1
again:
@@ -70,7 +70,7 @@ again:
CBNZ R3, again
RET
-TEXT ·OrUint64(SB),$0-16
+TEXT ·OrUint64(SB),NOSPLIT,$0-16
MOVD ptr+0(FP), R0
MOVD val+8(FP), R1
again:
@@ -80,7 +80,7 @@ again:
CBNZ R3, again
RET
-TEXT ·XorUint64(SB),$0-16
+TEXT ·XorUint64(SB),NOSPLIT,$0-16
MOVD ptr+0(FP), R0
MOVD val+8(FP), R1
again:
@@ -90,7 +90,7 @@ again:
CBNZ R3, again
RET
-TEXT ·CompareAndSwapUint64(SB),$0-32
+TEXT ·CompareAndSwapUint64(SB),NOSPLIT,$0-32
MOVD addr+0(FP), R0
MOVD old+8(FP), R1
MOVD new+16(FP), R2
diff --git a/pkg/atomicbitops/atomicbitops_noasm.go b/pkg/atomicbitops/atomicbitops_noasm.go
index 474c0c815..af6b1362e 100644
--- a/pkg/atomicbitops/atomicbitops_noasm.go
+++ b/pkg/atomicbitops/atomicbitops_noasm.go
@@ -21,6 +21,7 @@ import (
"sync/atomic"
)
+//go:nosplit
func AndUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -31,6 +32,7 @@ func AndUint32(addr *uint32, val uint32) {
}
}
+//go:nosplit
func OrUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -41,6 +43,7 @@ func OrUint32(addr *uint32, val uint32) {
}
}
+//go:nosplit
func XorUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -51,6 +54,7 @@ func XorUint32(addr *uint32, val uint32) {
}
}
+//go:nosplit
func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
for {
prev = atomic.LoadUint32(addr)
@@ -63,6 +67,7 @@ func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
}
}
+//go:nosplit
func AndUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -73,6 +78,7 @@ func AndUint64(addr *uint64, val uint64) {
}
}
+//go:nosplit
func OrUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -83,6 +89,7 @@ func OrUint64(addr *uint64, val uint64) {
}
}
+//go:nosplit
func XorUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -93,6 +100,7 @@ func XorUint64(addr *uint64, val uint64) {
}
}
+//go:nosplit
func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
for {
prev = atomic.LoadUint64(addr)
diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go
index 7bcfcd543..a4610f977 100644
--- a/pkg/buffer/view.go
+++ b/pkg/buffer/view.go
@@ -378,6 +378,20 @@ func (v *View) Copy() (other View) {
return
}
+// Clone makes a shallower copy than Copy. The underlying payload slice
+// (buffer.data) is shared, but the buffers themselves are copied.
+func (v *View) Clone() *View {
+ other := &View{
+ size: v.size,
+ }
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ newBuf := other.pool.getNoInit()
+ *newBuf = *buf
+ other.data.PushBack(newBuf)
+ }
+ return other
+}
+
// Apply applies the given function across all valid data.
func (v *View) Apply(fn func([]byte)) {
for buf := v.data.Front(); buf != nil; buf = buf.Next() {
diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go
index 796efa240..59784eacb 100644
--- a/pkg/buffer/view_test.go
+++ b/pkg/buffer/view_test.go
@@ -509,6 +509,24 @@ func TestView(t *testing.T) {
}
}
+func TestViewClone(t *testing.T) {
+ const (
+ originalSize = 90
+ bytesToDelete = 30
+ )
+ var v View
+ v.AppendOwned(bytes.Repeat([]byte{originalSize}, originalSize))
+
+ clonedV := v.Clone()
+ v.TrimFront(bytesToDelete)
+ if got, want := int(v.Size()), originalSize-bytesToDelete; got != want {
+		t.Errorf("original view was not trimmed: want size = %d, got = %d", want, got)
+ }
+ if got := clonedV.Size(); got != originalSize {
+		t.Errorf("cloned view should not be modified: want size = %d, got = %d", originalSize, got)
+ }
+}
+
func TestViewPullUp(t *testing.T) {
for _, tc := range []struct {
desc string
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index 69eeb7528..4d5e062a8 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -37,6 +37,14 @@ package cpuid
// arch/arm64/include/uapi/asm/hwcap.h
type Feature int
+// HostFeatureSet returns a FeatureSet that matches that of the host machine.
+// Callers must not mutate the returned FeatureSet.
+func HostFeatureSet() *FeatureSet {
+ return hostFeatureSet
+}
+
+var hostFeatureSet = getHostFeatureSet()
+
// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
// subset of the host feature set.
type ErrIncompatible struct {
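
Since HostFeatureSet now returns a shared, package-level object, callers that want to mutate the result should copy it first, e.g. with the Clone method added below. A hypothetical caller (amd64 assumed; X86FeatureAVX and HasFeature come from the existing cpuid API):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/cpuid"
)

func main() {
	// Clone before mutating: the FeatureSet from HostFeatureSet is shared.
	fs := cpuid.HostFeatureSet().Clone()
	fs.Set[cpuid.X86FeatureAVX] = false // safe: fs is a private copy
	fmt.Println(fs.HasFeature(cpuid.X86FeatureAVX)) // false
}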
diff --git a/pkg/cpuid/cpuid_arm64.go b/pkg/cpuid/cpuid_arm64.go
index 6e61d562f..04645aed5 100644
--- a/pkg/cpuid/cpuid_arm64.go
+++ b/pkg/cpuid/cpuid_arm64.go
@@ -230,6 +230,16 @@ type FeatureSet struct {
CPURevision uint8
}
+// Clone returns a copy of fs.
+func (fs *FeatureSet) Clone() *FeatureSet {
+ fs2 := *fs
+ fs2.Set = make(map[Feature]bool)
+ for f, b := range fs.Set {
+ fs2.Set[f] = b
+ }
+ return &fs2
+}
+
// CheckHostCompatible returns nil if fs is a subset of the host feature set.
// Noop on arm64.
func (fs *FeatureSet) CheckHostCompatible() error {
@@ -292,9 +302,9 @@ func (fs FeatureSet) WriteCPUInfoTo(cpu uint, b *bytes.Buffer) {
fmt.Fprintln(b, "") // The /proc/cpuinfo file ends with an extra newline.
}
-// HostFeatureSet uses hwCap to get host values and construct a feature set
+// getHostFeatureSet uses hwCap to get host values and construct a feature set
// that matches that of the host machine.
-func HostFeatureSet() *FeatureSet {
+func getHostFeatureSet() *FeatureSet {
s := make(map[Feature]bool)
for f := range arm64FeatureStrings {
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
index dc17cade8..a92d32d74 100644
--- a/pkg/cpuid/cpuid_x86.go
+++ b/pkg/cpuid/cpuid_x86.go
@@ -627,6 +627,17 @@ type FeatureSet struct {
CacheLine uint32
}
+// Clone returns a copy of fs.
+func (fs *FeatureSet) Clone() *FeatureSet {
+ fs2 := *fs
+ fs2.Set = make(map[Feature]bool)
+ for f, b := range fs.Set {
+ fs2.Set[f] = b
+ }
+ fs2.Caches = append([]Cache(nil), fs.Caches...)
+ return &fs2
+}
+
// FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is
// equivalent to the "flags" field in /proc/cpuinfo.
func (fs *FeatureSet) FlagsString(cpuinfoOnly bool) string {
@@ -961,13 +972,13 @@ func (fs *FeatureSet) UseXsaveopt() bool {
// HostID executes a native CPUID instruction.
func HostID(axArg, cxArg uint32) (ax, bx, cx, dx uint32)
-// HostFeatureSet uses cpuid to get host values and construct a feature set
+// getHostFeatureSet uses cpuid to get host values and construct a feature set
// that matches that of the host machine. Note that there are several places
// where there appear to be some unnecessary assignments between register names
// (ax, bx, cx, or dx) and featureBlockN variables. This is to explicitly show
// where the different feature blocks come from, to make the code easier to
// inspect and read.
-func HostFeatureSet() *FeatureSet {
+func getHostFeatureSet() *FeatureSet {
// eax=0 gets max supported feature and vendor ID.
_, bx, cx, dx := HostID(0, 0)
vendorID := vendorIDFromRegs(bx, cx, dx)
diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go
index 69e867386..28eba2ff6 100644
--- a/pkg/crypto/crypto_stdlib.go
+++ b/pkg/crypto/crypto_stdlib.go
@@ -19,14 +19,21 @@ package crypto
import (
"crypto/ecdsa"
+ "crypto/elliptic"
"crypto/sha512"
+ "fmt"
"math/big"
)
-// EcdsaVerify verifies the signature in r, s of hash using ECDSA and the
-// public key, pub. Its return value records whether the signature is valid.
-func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) (bool, error) {
- return ecdsa.Verify(pub, hash, r, s), nil
+// EcdsaP384Sha384Verify verifies the signature in r, s of data using ECDSA
+// P-384 + SHA-384 and the public key, pub. Its return value records whether
+// the signature is valid.
+func EcdsaP384Sha384Verify(pub *ecdsa.PublicKey, data []byte, r, s *big.Int) (bool, error) {
+ if pub.Curve != elliptic.P384() {
+ return false, fmt.Errorf("unsupported key curve: want P-384, got %v", pub.Curve)
+ }
+ digest := sha512.Sum384(data)
+ return ecdsa.Verify(pub, digest[:], r, s), nil
}
// SumSha384 returns the SHA384 checksum of the data.
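
A hedged usage sketch of the new signature (the key and data here are made up; note the verifier hashes internally, so callers pass the raw bytes rather than a digest):

package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha512"
	"fmt"

	gcrypto "gvisor.dev/gvisor/pkg/crypto"
)

func main() {
	key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
	if err != nil {
		panic(err)
	}
	data := []byte("attestation report")
	digest := sha512.Sum384(data)
	r, s, err := ecdsa.Sign(rand.Reader, key, digest[:])
	if err != nil {
		panic(err)
	}
	// The wrapper hashes data itself, so pass the raw bytes.
	ok, err := gcrypto.EcdsaP384Sha384Verify(&key.PublicKey, data, r, s)
	fmt.Println(ok, err) // true <nil>
}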
diff --git a/pkg/errors/linuxerr/BUILD b/pkg/errors/linuxerr/BUILD
index 201727780..e73b0e28a 100644
--- a/pkg/errors/linuxerr/BUILD
+++ b/pkg/errors/linuxerr/BUILD
@@ -4,7 +4,10 @@ package(licenses = ["notice"])
go_library(
name = "linuxerr",
- srcs = ["linuxerr.go"],
+ srcs = [
+ "internal.go",
+ "linuxerr.go",
+ ],
visibility = ["//visibility:public"],
deps = [
"//pkg/abi/linux/errno",
@@ -20,7 +23,6 @@ go_test(
":linuxerr",
"//pkg/abi/linux/errno",
"//pkg/errors",
- "//pkg/syserror",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/errors/linuxerr/internal.go b/pkg/errors/linuxerr/internal.go
new file mode 100644
index 000000000..127bba0df
--- /dev/null
+++ b/pkg/errors/linuxerr/internal.go
@@ -0,0 +1,120 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linuxerr
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux/errno"
+ "gvisor.dev/gvisor/pkg/errors"
+)
+
+var (
+ // ErrWouldBlock is an internal error used to indicate that an operation
+ // cannot be satisfied immediately, and should be retried at a later
+ // time, possibly when the caller has received a notification that the
+ // operation may be able to complete. It is used by implementations of
+ // the kio.File interface.
+ ErrWouldBlock = errors.New(errno.EWOULDBLOCK, "request would block")
+
+ // ErrInterrupted is returned if a request is interrupted before it can
+ // complete.
+ ErrInterrupted = errors.New(errno.EINTR, "request was interrupted")
+
+ // ErrExceedsFileSizeLimit is returned if a request would exceed the
+ // file's size limit.
+ ErrExceedsFileSizeLimit = errors.New(errno.E2BIG, "exceeds file size limit")
+)
+
+var errorMap = map[error]*errors.Error{
+ ErrWouldBlock: EWOULDBLOCK,
+ ErrInterrupted: EINTR,
+ ErrExceedsFileSizeLimit: EFBIG,
+}
+
+// errorUnwrappers is an array of unwrap functions to extract typed errors.
+var errorUnwrappers = []func(error) (*errors.Error, bool){}
+
+// AddErrorUnwrapper registers an unwrap method that can extract a concrete error
+// from a typed, but not initialized, error.
+func AddErrorUnwrapper(unwrap func(e error) (*errors.Error, bool)) {
+ errorUnwrappers = append(errorUnwrappers, unwrap)
+}
+
+// TranslateError translates errors to errnos; it returns false if the error
+// was not registered.
+func TranslateError(from error) (*errors.Error, bool) {
+ if err, ok := errorMap[from]; ok {
+ return err, true
+ }
+ // Try to unwrap the error if we couldn't match an error
+ // exactly. This might mean that a package has its own
+ // error type.
+ for _, unwrap := range errorUnwrappers {
+ if err, ok := unwrap(from); ok {
+ return err, true
+ }
+ }
+ return nil, false
+}
+
+// These errors are significant because ptrace syscall exit tracing can
+// observe them.
+//
+// For all of the following errors, if the syscall is not interrupted by a
+// signal delivered to a user handler, the syscall is restarted.
+var (
+ // ERESTARTSYS is returned by an interrupted syscall to indicate that it
+ // should be converted to EINTR if interrupted by a signal delivered to a
+ // user handler without SA_RESTART set, and restarted otherwise.
+ ERESTARTSYS = errors.New(errno.ERESTARTSYS, "to be restarted if SA_RESTART is set")
+
+ // ERESTARTNOINTR is returned by an interrupted syscall to indicate that it
+ // should always be restarted.
+ ERESTARTNOINTR = errors.New(errno.ERESTARTNOINTR, "to be restarted")
+
+ // ERESTARTNOHAND is returned by an interrupted syscall to indicate that it
+ // should be converted to EINTR if interrupted by a signal delivered to a
+ // user handler, and restarted otherwise.
+ ERESTARTNOHAND = errors.New(errno.ERESTARTNOHAND, "to be restarted if no handler")
+
+ // ERESTART_RESTARTBLOCK is returned by an interrupted syscall to indicate
+ // that it should be restarted using a custom function. The interrupted
+ // syscall must register a custom restart function by calling
+ // Task.SetRestartSyscallFn.
+ ERESTART_RESTARTBLOCK = errors.New(errno.ERESTART_RESTARTBLOCK, "interrupted by signal")
+)
+
+var restartMap = map[int]*errors.Error{
+ -int(errno.ERESTARTSYS): ERESTARTSYS,
+ -int(errno.ERESTARTNOINTR): ERESTARTNOINTR,
+ -int(errno.ERESTARTNOHAND): ERESTARTNOHAND,
+ -int(errno.ERESTART_RESTARTBLOCK): ERESTART_RESTARTBLOCK,
+}
+
+// IsRestartError checks if a given error is a restart error.
+func IsRestartError(err error) bool {
+ switch err {
+ case ERESTARTSYS, ERESTARTNOINTR, ERESTARTNOHAND, ERESTART_RESTARTBLOCK:
+ return true
+ default:
+ return false
+ }
+}
+
+// SyscallRestartErrorFromReturn returns the SyscallRestartErrno represented by
+// rv, the value in a syscall return register.
+func SyscallRestartErrorFromReturn(rv uintptr) (*errors.Error, bool) {
+ err, ok := restartMap[int(rv)]
+ return err, ok
+}
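
A hypothetical registration showing how a package with its own error type hooks into TranslateError via AddErrorUnwrapper (myPkgError is invented for illustration):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/errors"
	"gvisor.dev/gvisor/pkg/errors/linuxerr"
)

// myPkgError is a hypothetical package-private error type.
type myPkgError struct{ op string }

func (e *myPkgError) Error() string { return "failed: " + e.op }

func main() {
	// Register an unwrapper so TranslateError can map myPkgError to EIO.
	linuxerr.AddErrorUnwrapper(func(e error) (*errors.Error, bool) {
		if _, ok := e.(*myPkgError); ok {
			return linuxerr.EIO, true
		}
		return nil, false
	})
	err, ok := linuxerr.TranslateError(&myPkgError{"read"})
	fmt.Println(ok, err) // true, followed by the EIO error
}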
diff --git a/pkg/errors/linuxerr/linuxerr.go b/pkg/errors/linuxerr/linuxerr.go
index f9f8412e0..5905ef593 100644
--- a/pkg/errors/linuxerr/linuxerr.go
+++ b/pkg/errors/linuxerr/linuxerr.go
@@ -27,6 +27,12 @@ import (
const maxErrno uint32 = errno.EHWPOISON + 1
+// The following errors are semantically identical to Errno of type unix.Errno
+// or syscall.Errno. However, since the types are distinct (these are
+// *errors.Error), they are not directly comparable. The Errno method returns
+// an Errno number such that the error can be compared to unix/syscall.Errno
+// (e.g. unix.Errno(EPERM.Errno()) == unix.EPERM is true). Converting
+// unix/syscall.Errno to these errors should be done via the lookup methods
+// provided.
var (
NOERROR = errors.New(errno.NOERRNO, "not an error")
EPERM = errors.New(errno.EPERM, "operation not permitted")
@@ -177,7 +183,7 @@ var (
var errNotValidError = errors.New(errno.Errno(maxErrno), "not a valid error")
// The following errorSlice holds errors by errno for fast translation between
-// errnos (especially uint32(sycall.Errno)) and *Error.
+// errnos (especially uint32(syscall.Errno)) and *errors.Error.
var errorSlice = []*errors.Error{
// Errno values from include/uapi/asm-generic/errno-base.h.
errno.NOERRNO: NOERROR,
diff --git a/pkg/errors/linuxerr/linuxerr_test.go b/pkg/errors/linuxerr/linuxerr_test.go
index f09d61b02..df7cd1c5a 100644
--- a/pkg/errors/linuxerr/linuxerr_test.go
+++ b/pkg/errors/linuxerr/linuxerr_test.go
@@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package syserror_test
+package linuxerr_test
import (
"errors"
+ "fmt"
"io"
"io/fs"
"syscall"
@@ -25,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux/errno"
gErrors "gvisor.dev/gvisor/pkg/errors"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
- "gvisor.dev/gvisor/pkg/syserror"
)
var globalError error
@@ -42,12 +42,6 @@ func BenchmarkAssignLinuxerr(b *testing.B) {
}
}
-func BenchmarkAssignSyserror(b *testing.B) {
- for i := b.N; i > 0; i-- {
- globalError = linuxerr.ENOMSG
- }
-}
-
func BenchmarkCompareUnix(b *testing.B) {
globalError = unix.EAGAIN
j := 0
@@ -68,16 +62,6 @@ func BenchmarkCompareLinuxerr(b *testing.B) {
}
}
-func BenchmarkCompareSyserror(b *testing.B) {
- globalError = linuxerr.EAGAIN
- j := 0
- for i := b.N; i > 0; i-- {
- if globalError == linuxerr.EACCES {
- j++
- }
- }
-}
-
func BenchmarkSwitchUnix(b *testing.B) {
globalError = unix.EPERM
j := 0
@@ -108,21 +92,6 @@ func BenchmarkSwitchLinuxerr(b *testing.B) {
}
}
-func BenchmarkSwitchSyserror(b *testing.B) {
- globalError = linuxerr.EPERM
- j := 0
- for i := b.N; i > 0; i-- {
- switch globalError {
- case linuxerr.EACCES:
- j++
- case syserror.EINTR:
- j += 2
- case linuxerr.EAGAIN:
- j += 3
- }
- }
-}
-
func BenchmarkReturnUnix(b *testing.B) {
var localError error
f := func() error {
@@ -170,47 +139,40 @@ func BenchmarkConvertUnixLinuxerrZero(b *testing.B) {
}
type translationTestTable struct {
- fn string
errIn error
- syscallErrorIn unix.Errno
expectedBool bool
- expectedTranslation unix.Errno
+ expectedTranslation *gErrors.Error
}
func TestErrorTranslation(t *testing.T) {
- myError := errors.New("My test error")
- myError2 := errors.New("Another test error")
testTable := []translationTestTable{
- {"TranslateError", myError, 0, false, 0},
- {"TranslateError", myError2, 0, false, 0},
- {"AddErrorTranslation", myError, unix.EAGAIN, true, 0},
- {"AddErrorTranslation", myError, unix.EAGAIN, false, 0},
- {"AddErrorTranslation", myError, unix.EPERM, false, 0},
- {"TranslateError", myError, 0, true, unix.EAGAIN},
- {"TranslateError", myError2, 0, false, 0},
- {"AddErrorTranslation", myError2, unix.EPERM, true, 0},
- {"AddErrorTranslation", myError2, unix.EPERM, false, 0},
- {"AddErrorTranslation", myError2, unix.EAGAIN, false, 0},
- {"TranslateError", myError, 0, true, unix.EAGAIN},
- {"TranslateError", myError2, 0, true, unix.EPERM},
+ {
+ errIn: linuxerr.ENOENT,
+ },
+ {
+ errIn: unix.ENOENT,
+ },
+ {
+ errIn: linuxerr.ErrInterrupted,
+ expectedBool: true,
+ expectedTranslation: linuxerr.EINTR,
+ },
+ {
+ errIn: linuxerr.ERESTART_RESTARTBLOCK,
+ },
+ {
+ errIn: errors.New("some new error"),
+ },
}
for _, tt := range testTable {
- switch tt.fn {
- case "TranslateError":
- err, ok := syserror.TranslateError(tt.errIn)
- if ok != tt.expectedBool {
- t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool)
+ t.Run(fmt.Sprintf("err: %v %T", tt.errIn, tt.errIn), func(t *testing.T) {
+ err, ok := linuxerr.TranslateError(tt.errIn)
+ if (!tt.expectedBool && err != nil) || (tt.expectedBool != ok) {
+				t.Fatalf("TranslateError(%v) = (%v, %v), want ok = %v (and nil err when not found)", tt.errIn, err, ok, tt.expectedBool)
} else if err != tt.expectedTranslation {
- t.Fatalf("%v(%v) (error) => %v expected %v", tt.fn, tt.errIn, err, tt.expectedTranslation)
- }
- case "AddErrorTranslation":
- ok := syserror.AddErrorTranslation(tt.errIn, tt.syscallErrorIn)
- if ok != tt.expectedBool {
- t.Fatalf("%v(%v) => %v expected %v", tt.fn, tt.errIn, ok, tt.expectedBool)
+ t.Fatalf("%v => %v expected %v", tt.errIn, err, tt.expectedTranslation)
}
- default:
- t.Fatalf("Unknown function %v", tt.fn)
- }
+ })
}
}
diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD
index ad15d3672..56399232a 100644
--- a/pkg/eventchannel/BUILD
+++ b/pkg/eventchannel/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"event.go",
"event_any.go",
+ "processor.go",
"rate.go",
],
visibility = ["//:sandbox"],
diff --git a/pkg/eventchannel/event_any.go b/pkg/eventchannel/event_any.go
index 13f300061..b708937a4 100644
--- a/pkg/eventchannel/event_any.go
+++ b/pkg/eventchannel/event_any.go
@@ -26,3 +26,8 @@ import (
func newAny(m proto.Message) (*anypb.Any, error) {
return anypb.New(m)
}
+
+func emptyAny() *anypb.Any {
+ var any anypb.Any
+ return &any
+}
diff --git a/pkg/eventchannel/processor.go b/pkg/eventchannel/processor.go
new file mode 100644
index 000000000..e765c10d1
--- /dev/null
+++ b/pkg/eventchannel/processor.go
@@ -0,0 +1,130 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventchannel
+
+import (
+ "encoding/binary"
+ "fmt"
+ "io"
+ "os"
+ "time"
+
+ "google.golang.org/protobuf/proto"
+ pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto"
+)
+
+// eventProcessor carries display state across multiple events.
+type eventProcessor struct {
+ filtering bool
+ // filtered is the number of events omitted since printing the last matching
+ // event. Only meaningful when filtering == true.
+ filtered uint64
+ // allowlist is the set of event names to display. If empty, all events are
+ // displayed.
+ allowlist map[string]bool
+}
+
+// newEventProcessor creates a new EventProcessor with filters.
+func newEventProcessor(filters []string) *eventProcessor {
+ e := &eventProcessor{
+ filtering: len(filters) > 0,
+ allowlist: make(map[string]bool),
+ }
+ for _, f := range filters {
+ e.allowlist[f] = true
+ }
+ return e
+}
+
+// processOne reads, parses and displays a single event from the event channel.
+//
+// The event channel is a stream of (msglen, payload) packets; this function
+// processes a single such packet. The msglen is a uvarint-encoded length for
+// the associated payload. The payload is a binary-encoded 'Any' protobuf, which
+// in turn encodes an arbitrary event protobuf.
+func (e *eventProcessor) processOne(src io.Reader, out *os.File) error {
+ // Read and parse the msglen.
+ lenbuf := make([]byte, binary.MaxVarintLen64)
+ if _, err := io.ReadFull(src, lenbuf); err != nil {
+ return err
+ }
+ msglen, consumed := binary.Uvarint(lenbuf)
+ if consumed <= 0 {
+ return fmt.Errorf("couldn't parse the message length")
+ }
+
+ // Read the payload.
+ buf := make([]byte, msglen)
+ // Copy any unused bytes from the len buffer into the payload buffer. These
+ // bytes are actually part of the payload.
+ extraBytes := copy(buf, lenbuf[consumed:])
+ if _, err := io.ReadFull(src, buf[extraBytes:]); err != nil {
+ return err
+ }
+
+ // Unmarshal the payload into an "Any" protobuf, which encodes the actual
+ // event.
+ encodedEv := emptyAny()
+ if err := proto.Unmarshal(buf, encodedEv); err != nil {
+ return fmt.Errorf("failed to unmarshal 'any' protobuf message: %v", err)
+ }
+
+ var ev pb.DebugEvent
+ if err := (encodedEv).UnmarshalTo(&ev); err != nil {
+ return fmt.Errorf("failed to decode 'any' protobuf message: %v", err)
+ }
+
+	if e.filtering && !e.allowlist[ev.Name] {
+ e.filtered++
+ return nil
+ }
+
+ if e.filtering && e.filtered > 0 {
+ if e.filtered == 1 {
+ fmt.Fprintf(out, "... filtered %d event ...\n\n", e.filtered)
+ } else {
+ fmt.Fprintf(out, "... filtered %d events ...\n\n", e.filtered)
+ }
+ e.filtered = 0
+ }
+
+ // Extract the inner event and display it. Example:
+ //
+ // 2017-10-04 14:35:05.316180374 -0700 PDT m=+1.132485846
+ // cloud_gvisor.MemoryUsage {
+ // total: 23822336
+ // }
+ fmt.Fprintf(out, "%v\n%v {\n", time.Now(), ev.Name)
+ fmt.Fprintf(out, "%v", ev.Text)
+ fmt.Fprintf(out, "}\n\n")
+
+ return nil
+}
+
+// ProcessAll reads, parses and displays all events from src. The events are
+// displayed to out.
+func ProcessAll(src io.Reader, filters []string, out *os.File) error {
+ ep := newEventProcessor(filters)
+ for {
+ switch err := ep.processOne(src, out); err {
+ case nil:
+ continue
+ case io.EOF:
+ return nil
+ default:
+ return err
+ }
+ }
+}
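
To make the framing concrete, here is a sketch that writes a single (msglen, payload) packet of the kind processOne expects and feeds it back through ProcessAll (the event name and text are made up):

package main

import (
	"bytes"
	"encoding/binary"
	"os"

	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/anypb"
	"gvisor.dev/gvisor/pkg/eventchannel"
	pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto"
)

func main() {
	// Build the payload: an Any wrapping the actual event protobuf.
	ev := &pb.DebugEvent{Name: "example.Event", Text: "  total: 1\n"}
	anyEv, err := anypb.New(ev)
	if err != nil {
		panic(err)
	}
	payload, err := proto.Marshal(anyEv)
	if err != nil {
		panic(err)
	}

	// Prefix the payload with its uvarint-encoded length.
	var frame bytes.Buffer
	var lenbuf [binary.MaxVarintLen64]byte
	n := binary.PutUvarint(lenbuf[:], uint64(len(payload)))
	frame.Write(lenbuf[:n])
	frame.Write(payload)

	// ProcessAll parses the stream and prints the event to stdout.
	if err := eventchannel.ProcessAll(&frame, nil, os.Stdout); err != nil {
		panic(err)
	}
}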
diff --git a/pkg/eventfd/BUILD b/pkg/eventfd/BUILD
new file mode 100644
index 000000000..02407cb99
--- /dev/null
+++ b/pkg/eventfd/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "eventfd",
+ srcs = [
+ "eventfd.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/hostarch",
+ "//pkg/tcpip/link/rawfile",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "eventfd_test",
+ srcs = ["eventfd_test.go"],
+ library = ":eventfd",
+)
diff --git a/pkg/eventfd/eventfd.go b/pkg/eventfd/eventfd.go
new file mode 100644
index 000000000..acdac01b8
--- /dev/null
+++ b/pkg/eventfd/eventfd.go
@@ -0,0 +1,115 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package eventfd wraps Linux's eventfd(2) syscall.
+package eventfd
+
+import (
+ "fmt"
+ "io"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+)
+
+const sizeofUint64 = 8
+
+// Eventfd represents a Linux eventfd object.
+type Eventfd struct {
+ fd int
+}
+
+// Create returns an initialized eventfd.
+func Create() (Eventfd, error) {
+ fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0)
+ if err != 0 {
+ return Eventfd{}, fmt.Errorf("failed to create eventfd: %v", error(err))
+ }
+ if err := unix.SetNonblock(int(fd), true); err != nil {
+ unix.Close(int(fd))
+ return Eventfd{}, err
+ }
+ return Eventfd{int(fd)}, nil
+}
+
+// Wrap returns an initialized Eventfd using the provided fd.
+func Wrap(fd int) Eventfd {
+ return Eventfd{fd}
+}
+
+// Close closes the eventfd, after which it should not be used.
+func (ev Eventfd) Close() error {
+ return unix.Close(ev.fd)
+}
+
+// Dup copies the eventfd, calling dup(2) on the underlying file descriptor.
+func (ev Eventfd) Dup() (Eventfd, error) {
+ other, err := unix.Dup(ev.fd)
+ if err != nil {
+ return Eventfd{}, fmt.Errorf("failed to dup: %v", err)
+ }
+ return Eventfd{other}, nil
+}
+
+// Notify alerts other users of the eventfd. Users can receive alerts by
+// calling Wait or Read.
+func (ev Eventfd) Notify() error {
+ return ev.Write(1)
+}
+
+// Write writes a specific value to the eventfd.
+func (ev Eventfd) Write(val uint64) error {
+ var buf [sizeofUint64]byte
+ hostarch.ByteOrder.PutUint64(buf[:], val)
+ for {
+ n, err := unix.Write(ev.fd, buf[:])
+ if err == unix.EINTR {
+ continue
+ }
+ if err != nil {
+ return err
+ }
+ if n != sizeofUint64 {
+ panic(fmt.Sprintf("short write to eventfd: got %d bytes, wanted %d", n, sizeofUint64))
+ }
+ return nil
+ }
+}
+
+// Wait blocks until eventfd is non-zero (i.e. someone calls Notify or Write).
+func (ev Eventfd) Wait() error {
+ _, err := ev.Read()
+ return err
+}
+
+// Read blocks until eventfd is non-zero (i.e. someone calls Notify or Write)
+// and returns the value read.
+func (ev Eventfd) Read() (uint64, error) {
+ var tmp [sizeofUint64]byte
+ n, err := rawfile.BlockingReadUntranslated(ev.fd, tmp[:])
+ if err != 0 {
+ return 0, err
+ }
+ if n == 0 {
+ return 0, io.EOF
+ }
+ if n != sizeofUint64 {
+ panic(fmt.Sprintf("short read from eventfd: got %d bytes, wanted %d", n, sizeofUint64))
+ }
+ return hostarch.ByteOrder.Uint64(tmp[:]), nil
+}
+
+// FD returns the underlying file descriptor. Use with care, as this breaks the
+// Eventfd abstraction.
+func (ev Eventfd) FD() int {
+ return ev.fd
+}
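A minimal usage sketch for the package (not part of this change); the counter semantics follow eventfd(2), so the printed value depends on how many Notify calls land before the Read.

    package main

    import (
        "fmt"

        "gvisor.dev/gvisor/pkg/eventfd"
    )

    func main() {
        efd, err := eventfd.Create()
        if err != nil {
            panic(err)
        }
        defer efd.Close()

        go func() {
            // Each Notify adds 1 to the eventfd's counter.
            efd.Notify()
            efd.Notify()
        }()

        // Read blocks until the counter is non-zero, returns its value, and
        // resets it to zero. Depending on scheduling this prints 1 or 2.
        val, err := efd.Read()
        if err != nil {
            panic(err)
        }
        fmt.Println("counter:", val)
    }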
diff --git a/pkg/eventfd/eventfd_test.go b/pkg/eventfd/eventfd_test.go
new file mode 100644
index 000000000..96998d530
--- /dev/null
+++ b/pkg/eventfd/eventfd_test.go
@@ -0,0 +1,75 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventfd
+
+import (
+ "testing"
+ "time"
+)
+
+func TestReadWrite(t *testing.T) {
+ efd, err := Create()
+ if err != nil {
+ t.Fatalf("failed to Create(): %v", err)
+ }
+ defer efd.Close()
+
+ // Make sure we can read back the value we wrote.
+ const want = 343
+ if err := efd.Write(want); err != nil {
+ t.Fatalf("failed to write value: %d", want)
+ }
+
+ got, err := efd.Read()
+ if err != nil {
+ t.Fatalf("failed to read value: %v", err)
+ }
+ if got != want {
+ t.Fatalf("Read(): got %d, but wanted %d", got, want)
+ }
+}
+
+func TestWait(t *testing.T) {
+ efd, err := Create()
+ if err != nil {
+ t.Fatalf("failed to Create(): %v", err)
+ }
+ defer efd.Close()
+
+ // There's no way to test with certainty that Wait() blocks indefinitely, but
+ // as a best-effort we can wait a bit on it.
+ errCh := make(chan error)
+ go func() {
+ errCh <- efd.Wait()
+ }()
+ select {
+ case err := <-errCh:
+ t.Fatalf("Wait() returned without a call to Notify(): %v", err)
+ case <-time.After(500 * time.Millisecond):
+ }
+
+ // Notify and check that Wait() returned.
+ if err := efd.Notify(); err != nil {
+ t.Fatalf("Notify() failed: %v", err)
+ }
+ select {
+ case err := <-errCh:
+ if err != nil {
+ t.Fatalf("Read() failed: %v", err)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Read() did not return after Notify()")
+ }
+}
diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go
index 5d2ee4018..99410628f 100644
--- a/pkg/flipcall/ctrl_futex.go
+++ b/pkg/flipcall/ctrl_futex.go
@@ -121,7 +121,7 @@ func (ep *Endpoint) ctrlWaitFirst() error {
return ep.futexWaitUntilActive()
}
-func (ep *Endpoint) ctrlRoundTrip() error {
+func (ep *Endpoint) ctrlRoundTrip(mayRetainP bool) error {
if err := ep.enterFutexWait(); err != nil {
return err
}
@@ -133,6 +133,9 @@ func (ep *Endpoint) ctrlRoundTrip() error {
if err := ep.futexWakePeer(); err != nil {
return err
}
+ // Since we don't know if the peer Endpoint is in the same process as this
+ // one (in which case it may need our P to run), we allow our P to be
+ // retaken regardless of mayRetainP.
return ep.futexWaitUntilActive()
}
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index f0e4ff487..88588ba0e 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -223,6 +223,23 @@ func (ep *Endpoint) RecvFirst() (uint32, error) {
// * If ep is a client Endpoint, ep.Connect() has previously been called and
// returned nil.
func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) {
+ return ep.sendRecv(dataLen, false /* mayRetainP */)
+}
+
+// SendRecvFast is equivalent to SendRecv, but may prevent the caller's runtime
+// P from being released, in which case the calling goroutine continues to
+// count against GOMAXPROCS while waiting for the peer Endpoint to return
+// control to the caller.
+//
+// SendRecvFast is appropriate if the peer Endpoint is expected to consistently
+// return control in a short amount of time (less than ~10ms).
+//
+// Preconditions: As for SendRecv.
+func (ep *Endpoint) SendRecvFast(dataLen uint32) (uint32, error) {
+ return ep.sendRecv(dataLen, true /* mayRetainP */)
+}
+
+func (ep *Endpoint) sendRecv(dataLen uint32, mayRetainP bool) (uint32, error) {
if dataLen > ep.dataCap {
panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap))
}
@@ -233,7 +250,7 @@ func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) {
// they can only shoot themselves in the foot.
*ep.dataLen() = dataLen
raceBecomeInactive()
- if err := ep.ctrlRoundTrip(); err != nil {
+ if err := ep.ctrlRoundTrip(mayRetainP); err != nil {
return 0, err
}
raceBecomeActive()
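A hedged sketch of when SendRecvFast pays off: a client round trip against a peer that is known to reply quickly. Endpoint creation and connection are elided; only Data() and SendRecvFast() from the flipcall API are relied on.

    // call performs one client round trip over a connected flipcall Endpoint.
    // req must fit in the endpoint's data window. The returned slice aliases
    // the shared window and is only valid until the next send.
    func call(ep *flipcall.Endpoint, req []byte) ([]byte, error) {
        n := copy(ep.Data(), req)
        // The peer is expected to return control in well under ~10ms, so
        // retaining our P is cheaper than handing it back to the scheduler.
        respLen, err := ep.SendRecvFast(uint32(n))
        if err != nil {
            return nil, err
        }
        return ep.Data()[:respLen], nil
    }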
diff --git a/pkg/lisafs/BUILD b/pkg/lisafs/BUILD
new file mode 100644
index 000000000..313c1756d
--- /dev/null
+++ b/pkg/lisafs/BUILD
@@ -0,0 +1,117 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_template_instance(
+ name = "control_fd_refs",
+ out = "control_fd_refs.go",
+ package = "lisafs",
+ prefix = "controlFD",
+ template = "//pkg/refsvfs2:refs_template",
+ types = {
+ "T": "ControlFD",
+ },
+)
+
+go_template_instance(
+ name = "open_fd_refs",
+ out = "open_fd_refs.go",
+ package = "lisafs",
+ prefix = "openFD",
+ template = "//pkg/refsvfs2:refs_template",
+ types = {
+ "T": "OpenFD",
+ },
+)
+
+go_template_instance(
+ name = "control_fd_list",
+ out = "control_fd_list.go",
+ package = "lisafs",
+ prefix = "controlFD",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*ControlFD",
+ "Linker": "*ControlFD",
+ },
+)
+
+go_template_instance(
+ name = "open_fd_list",
+ out = "open_fd_list.go",
+ package = "lisafs",
+ prefix = "openFD",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*OpenFD",
+ "Linker": "*OpenFD",
+ },
+)
+
+go_library(
+ name = "lisafs",
+ srcs = [
+ "channel.go",
+ "client.go",
+ "client_file.go",
+ "communicator.go",
+ "connection.go",
+ "control_fd_list.go",
+ "control_fd_refs.go",
+ "fd.go",
+ "handlers.go",
+ "lisafs.go",
+ "message.go",
+ "open_fd_list.go",
+ "open_fd_refs.go",
+ "sample_message.go",
+ "server.go",
+ "sock.go",
+ ],
+ marshal = True,
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/cleanup",
+ "//pkg/context",
+ "//pkg/fdchannel",
+ "//pkg/flipcall",
+ "//pkg/fspath",
+ "//pkg/hostarch",
+ "//pkg/log",
+ "//pkg/marshal/primitive",
+ "//pkg/p9",
+ "//pkg/refsvfs2",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "sock_test",
+ size = "small",
+ srcs = ["sock_test.go"],
+ library = ":lisafs",
+ deps = [
+ "//pkg/marshal",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "connection_test",
+ size = "small",
+ srcs = ["connection_test.go"],
+ deps = [
+ ":lisafs",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md
index 51d0d40e5..6b857321a 100644
--- a/pkg/lisafs/README.md
+++ b/pkg/lisafs/README.md
@@ -1,5 +1,8 @@
# Replacing 9P
+NOTE: LISAFS is **NOT** production ready. There are still some security concerns
+that must be resolved first.
+
## Background
The Linux filesystem model consists of the following key aspects (modulo mounts,
diff --git a/pkg/lisafs/channel.go b/pkg/lisafs/channel.go
new file mode 100644
index 000000000..301212e51
--- /dev/null
+++ b/pkg/lisafs/channel.go
@@ -0,0 +1,190 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "math"
+ "runtime"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+var (
+ chanHeaderLen = uint32((*channelHeader)(nil).SizeBytes())
+)
+
+// maxChannels returns the number of channels a client can create.
+//
+// The server will reject channel creation requests beyond this (per client).
+// Note that we don't want the number of channels to be too large, because each
+// accounts for a large region of shared memory.
+// TODO(gvisor.dev/issue/6313): Tune the number of channels.
+func maxChannels() int {
+ maxChans := runtime.GOMAXPROCS(0)
+ if maxChans < 2 {
+ maxChans = 2
+ }
+ if maxChans > 4 {
+ maxChans = 4
+ }
+ return maxChans
+}
+
+// channel implements Communicator. It is the communication endpoint shared
+// by the client and the server and is used to perform fast IPC. Apart from
+// carrying data, a channel is also capable of donating file descriptors.
+type channel struct {
+ fdTracker
+ dead bool
+ data flipcall.Endpoint
+ fdChan fdchannel.Endpoint
+}
+
+var _ Communicator = (*channel)(nil)
+
+// PayloadBuf implements Communicator.PayloadBuf.
+func (ch *channel) PayloadBuf(size uint32) []byte {
+ return ch.data.Data()[chanHeaderLen : chanHeaderLen+size]
+}
+
+// SndRcvMessage implements Communicator.SndRcvMessage.
+func (ch *channel) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) {
+ // Write header. Requests cannot donate FDs.
+ ch.marshalHdr(m, 0 /* numFDs */)
+
+ // One-shot communication. RPCs are expected to be quick rather than block.
+ rcvDataLen, err := ch.data.SendRecvFast(chanHeaderLen + payloadLen)
+ if err != nil {
+ // This channel is now unusable.
+ ch.dead = true
+ // Map the transport errors to EIO, but also log the real error.
+ log.Warningf("lisafs.sndRcvMessage: flipcall.Endpoint.SendRecv: %v", err)
+ return 0, 0, unix.EIO
+ }
+
+ return ch.rcvMsg(rcvDataLen)
+}
+
+func (ch *channel) shutdown() {
+ ch.data.Shutdown()
+}
+
+func (ch *channel) destroy() {
+ ch.dead = true
+ ch.fdChan.Destroy()
+ ch.data.Destroy()
+}
+
+// createChannel creates a server-side channel. It returns a packet window
+// descriptor (for the data channel) and an open socket for the FD channel.
+func (c *Connection) createChannel(maxMessageSize uint32) (*channel, flipcall.PacketWindowDescriptor, int, error) {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ // If c.channels is nil, the connection has closed.
+ if c.channels == nil || len(c.channels) >= maxChannels() {
+ return nil, flipcall.PacketWindowDescriptor{}, -1, unix.ENOSYS
+ }
+ ch := &channel{}
+
+ // Set up data channel.
+ desc, err := c.channelAlloc.Allocate(flipcall.PacketHeaderBytes + int(chanHeaderLen+maxMessageSize))
+ if err != nil {
+ return nil, flipcall.PacketWindowDescriptor{}, -1, err
+ }
+ if err := ch.data.Init(flipcall.ServerSide, desc); err != nil {
+ return nil, flipcall.PacketWindowDescriptor{}, -1, err
+ }
+
+ // Set up FD channel.
+ fdSocks, err := fdchannel.NewConnectedSockets()
+ if err != nil {
+ ch.data.Destroy()
+ return nil, flipcall.PacketWindowDescriptor{}, -1, err
+ }
+ ch.fdChan.Init(fdSocks[0])
+ clientFDSock := fdSocks[1]
+
+ c.channels = append(c.channels, ch)
+ return ch, desc, clientFDSock, nil
+}
+
+// sendFDs sends as many FDs as it can. A failure to send an FD does not
+// return an error or fail the entire RPC; FDs are supplementary responses
+// that are not critical to the RPC response itself. However, a failure to
+// send the i-th FD prevents all subsequent FDs from being sent as well,
+// because the order in which FDs are donated is significant.
+func (ch *channel) sendFDs(fds []int) uint8 {
+ numFDs := len(fds)
+ if numFDs == 0 {
+ return 0
+ }
+
+ if numFDs > math.MaxUint8 {
+ log.Warningf("dropping all FDs because too many FDs to donate: %v", numFDs)
+ return 0
+ }
+
+ for i, fd := range fds {
+ if err := ch.fdChan.SendFD(fd); err != nil {
+ log.Warningf("error occurred while sending (%d/%d)th FD on channel(%p): %v", i+1, numFDs, ch, err)
+ return uint8(i)
+ }
+ }
+ return uint8(numFDs)
+}
+
+// channelHeader is the header that precedes each message received on a
+// flipcall endpoint when protocol version 1 is in use.
+//
+// +marshal
+type channelHeader struct {
+ message MID
+ numFDs uint8
+ _ uint8 // Need to make struct packed.
+}
+
+func (ch *channel) marshalHdr(m MID, numFDs uint8) {
+ header := &channelHeader{
+ message: m,
+ numFDs: numFDs,
+ }
+ header.MarshalUnsafe(ch.data.Data())
+}
+
+func (ch *channel) rcvMsg(dataLen uint32) (MID, uint32, error) {
+ if dataLen < chanHeaderLen {
+ log.Warningf("received data has size smaller than header length: %d", dataLen)
+ return 0, 0, unix.EIO
+ }
+
+ // Read header first.
+ var header channelHeader
+ header.UnmarshalUnsafe(ch.data.Data())
+
+ // Read any FDs.
+ for i := 0; i < int(header.numFDs); i++ {
+ fd, err := ch.fdChan.RecvFDNonblock()
+ if err != nil {
+ log.Warningf("expected %d FDs, received %d successfully, got err after that: %v", header.numFDs, i, err)
+ break
+ }
+ ch.TrackFD(fd)
+ }
+
+ return header.message, dataLen - chanHeaderLen, nil
+}
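To summarize the framing implemented by marshalHdr and rcvMsg above, a hedged layout sketch, assuming MID is a 2-byte identifier (which is what makes channelHeader a packed 4 bytes):

    // Layout of one message in the flipcall data window:
    //
    //	[0, 2)                message (MID)
    //	[2, 3)                numFDs
    //	[3, 4)                padding
    //	[chanHeaderLen, ...)  payload (payloadLen bytes)
    //
    // A message therefore occupies chanHeaderLen+payloadLen bytes of the
    // window, PayloadBuf(size) hands out the slice just past the header, and
    // any donated FDs travel out of band over fdChan.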
diff --git a/pkg/lisafs/client.go b/pkg/lisafs/client.go
new file mode 100644
index 000000000..ccf1b9f72
--- /dev/null
+++ b/pkg/lisafs/client.go
@@ -0,0 +1,432 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "fmt"
+ "math"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+const (
+ // fdsToCloseBatchSize is the number of FDs that are accumulated before a
+ // Close RPC is made to close them all. fdsToCloseBatchSize is immutable.
+ fdsToCloseBatchSize = 100
+)
+
+// Client helps manage a connection to the lisafs server and pass messages
+// efficiently. There is a 1:1 mapping between a Connection and a Client.
+type Client struct {
+ // sockComm is the main socket by which this connection is established.
+ // Communication over the socket is synchronized by sockMu.
+ sockMu sync.Mutex
+ sockComm *sockCommunicator
+
+ // channelsMu protects channels and availableChannels.
+ channelsMu sync.Mutex
+ // channels tracks all the channels.
+ channels []*channel
+ // availableChannels is a LIFO (stack) of channels available to be used.
+ availableChannels []*channel
+ // activeWg represents active channels.
+ activeWg sync.WaitGroup
+
+ // watchdogWg only holds the watchdog goroutine.
+ watchdogWg sync.WaitGroup
+
+ // supported caches information about which messages are supported. It is
+ // indexed by MID. An MID is supported if supported[MID] is true.
+ supported []bool
+
+ // maxMessageSize is the maximum payload length (in bytes) that can be sent.
+ // It is initialized on Mount and is immutable.
+ maxMessageSize uint32
+
+ // fdsToClose tracks the FDs to close. It caches the FDs no longer being used
+ // by the client and closes them in one shot. It is not preserved across
+ // checkpoint/restore as FDIDs are not preserved.
+ fdsMu sync.Mutex
+ fdsToClose []FDID
+}
+
+// NewClient creates a new client for communication with the server. It mounts
+// the server and creates channels for fast IPC. NewClient takes ownership over
+// the passed socket. On success, it returns the initialized client along with
+// the root Inode.
+func NewClient(sock *unet.Socket, mountPath string) (*Client, *Inode, error) {
+ maxChans := maxChannels()
+ c := &Client{
+ sockComm: newSockComm(sock),
+ channels: make([]*channel, 0, maxChans),
+ availableChannels: make([]*channel, 0, maxChans),
+ maxMessageSize: 1 << 20, // 1 MB for now.
+ fdsToClose: make([]FDID, 0, fdsToCloseBatchSize),
+ }
+
+ // Start a goroutine to check socket health. This goroutine is also
+ // responsible for client cleanup.
+ c.watchdogWg.Add(1)
+ go c.watchdog()
+
+ // Clean everything up if anything fails.
+ cu := cleanup.Make(func() {
+ c.Close()
+ })
+ defer cu.Clean()
+
+ // Mount the server first. Assume Mount is supported so that we can make the
+ // Mount RPC below.
+ c.supported = make([]bool, Mount+1)
+ c.supported[Mount] = true
+ mountMsg := MountReq{
+ MountPath: SizedString(mountPath),
+ }
+ var mountResp MountResp
+ if err := c.SndRcvMessage(Mount, uint32(mountMsg.SizeBytes()), mountMsg.MarshalBytes, mountResp.UnmarshalBytes, nil); err != nil {
+ return nil, nil, err
+ }
+
+ // Initialize client.
+ c.maxMessageSize = uint32(mountResp.MaxMessageSize)
+ var maxSuppMID MID
+ for _, suppMID := range mountResp.SupportedMs {
+ if suppMID > maxSuppMID {
+ maxSuppMID = suppMID
+ }
+ }
+ c.supported = make([]bool, maxSuppMID+1)
+ for _, suppMID := range mountResp.SupportedMs {
+ c.supported[suppMID] = true
+ }
+
+ // Create channels in parallel so that existing channels can be used to
+ // create more channels, and so that costly initialization like
+ // flipcall.Endpoint.Connect can proceed concurrently.
+ var channelsWg sync.WaitGroup
+ channelErrs := make([]error, maxChans)
+ for i := 0; i < maxChans; i++ {
+ channelsWg.Add(1)
+ curChanID := i
+ go func() {
+ defer channelsWg.Done()
+ ch, err := c.createChannel()
+ if err != nil {
+ log.Warningf("channel creation failed: %v", err)
+ channelErrs[curChanID] = err
+ return
+ }
+ c.channelsMu.Lock()
+ c.channels = append(c.channels, ch)
+ c.availableChannels = append(c.availableChannels, ch)
+ c.channelsMu.Unlock()
+ }()
+ }
+ channelsWg.Wait()
+
+ for _, channelErr := range channelErrs {
+ // Return the first non-nil channel creation error.
+ if channelErr != nil {
+ return nil, nil, channelErr
+ }
+ }
+ cu.Release()
+
+ return c, &mountResp.Root, nil
+}
+
+func (c *Client) watchdog() {
+ defer c.watchdogWg.Done()
+
+ events := []unix.PollFd{
+ {
+ Fd: int32(c.sockComm.FD()),
+ Events: unix.POLLHUP | unix.POLLRDHUP,
+ },
+ }
+
+ // Wait for a shutdown event.
+ for {
+ n, err := unix.Ppoll(events, nil, nil)
+ if err == unix.EINTR || err == unix.EAGAIN {
+ continue
+ }
+ if err != nil {
+ log.Warningf("lisafs.Client.watch(): %v", err)
+ } else if n != 1 {
+ log.Warningf("lisafs.Client.watch(): got %d events, wanted 1", n)
+ }
+ break
+ }
+
+ // Shutdown all active channels and wait for them to complete.
+ c.shutdownActiveChans()
+ c.activeWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.destroy()
+ }
+ c.channelsMu.Unlock()
+
+ // Close main socket.
+ c.sockComm.destroy()
+}
+
+func (c *Client) shutdownActiveChans() {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+
+ availableChans := make(map[*channel]bool)
+ for _, ch := range c.availableChannels {
+ availableChans[ch] = true
+ }
+ for _, ch := range c.channels {
+ // A channel that is not available is active.
+ if _, ok := availableChans[ch]; !ok {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.shutdown()
+ }
+ }
+
+ // Prevent channels from becoming available and serving new requests.
+ c.availableChannels = nil
+}
+
+// Close shuts down the main socket and waits for the watchdog to clean up.
+func (c *Client) Close() {
+ // This shutdown has no effect if the watchdog has already fired and closed
+ // the main socket.
+ c.sockComm.shutdown()
+ c.watchdogWg.Wait()
+}
+
+func (c *Client) createChannel() (*channel, error) {
+ var chanResp ChannelResp
+ var fds [2]int
+ if err := c.SndRcvMessage(Channel, 0, NoopMarshal, chanResp.UnmarshalUnsafe, fds[:]); err != nil {
+ return nil, err
+ }
+ if fds[0] < 0 || fds[1] < 0 {
+ closeFDs(fds[:])
+ return nil, fmt.Errorf("insufficient FDs provided in Channel response: %v", fds)
+ }
+
+ // Let's create the channel.
+ defer closeFDs(fds[:1]) // The data FD is not needed after this.
+ desc := flipcall.PacketWindowDescriptor{
+ FD: fds[0],
+ Offset: chanResp.dataOffset,
+ Length: int(chanResp.dataLength),
+ }
+
+ ch := &channel{}
+ if err := ch.data.Init(flipcall.ClientSide, desc); err != nil {
+ closeFDs(fds[1:])
+ return nil, err
+ }
+ ch.fdChan.Init(fds[1]) // fdChan now owns this FD.
+
+ // Only a connected channel is usable.
+ if err := ch.data.Connect(); err != nil {
+ ch.destroy()
+ return nil, err
+ }
+ return ch, nil
+}
+
+// IsSupported returns true if this connection supports the passed message.
+func (c *Client) IsSupported(m MID) bool {
+ return int(m) < len(c.supported) && c.supported[m]
+}
+
+// CloseFDBatched either queues the passed FD to be closed or makes a batch
+// RPC to close all the accumulated FDs-to-close.
+func (c *Client) CloseFDBatched(ctx context.Context, fd FDID) {
+ c.fdsMu.Lock()
+ c.fdsToClose = append(c.fdsToClose, fd)
+ if len(c.fdsToClose) < fdsToCloseBatchSize {
+ c.fdsMu.Unlock()
+ return
+ }
+
+ // Flush the cache. We should not hold fdsMu while making an RPC, so be sure
+ // to copy the fdsToClose to another buffer before unlocking fdsMu.
+ var toCloseArr [fdsToCloseBatchSize]FDID
+ toClose := toCloseArr[:len(c.fdsToClose)]
+ copy(toClose, c.fdsToClose)
+
+ // Clear fdsToClose so other FDIDs can be appended.
+ c.fdsToClose = c.fdsToClose[:0]
+ c.fdsMu.Unlock()
+
+ req := CloseReq{FDs: toClose}
+ ctx.UninterruptibleSleepStart(false)
+ err := c.SndRcvMessage(Close, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ log.Warningf("lisafs: batch closing FDs returned error: %v", err)
+ }
+}
+
+// SyncFDs makes a Fsync RPC to sync multiple FDs.
+func (c *Client) SyncFDs(ctx context.Context, fds []FDID) error {
+ if len(fds) == 0 {
+ return nil
+ }
+ req := FsyncReq{FDs: fds}
+ ctx.UninterruptibleSleepStart(false)
+ err := c.SndRcvMessage(FSync, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// SndRcvMessage invokes reqMarshal to marshal the request onto the payload
+// buffer, wakes up the server to process the request, waits for the response
+// and invokes respUnmarshal with the response payload. respFDs is populated
+// with the received FDs; extra entries are set to -1.
+//
+// Note that the function arguments intentionally accept marshal.Marshallable
+// functions like Marshal{Bytes/Unsafe} and Unmarshal{Bytes/Unsafe} instead of
+// directly accepting the marshal.Marshallable interface. Even though just
+// accepting marshal.Marshallable is cleaner, it leads to a heap allocation
+// (even if that interface variable itself does not escape). In other words,
+// implicit conversion to an interface leads to an allocation.
+//
+// Precondition: reqMarshal and respUnmarshal must be non-nil.
+func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal func(dst []byte), respUnmarshal func(src []byte), respFDs []int) error {
+ if !c.IsSupported(m) {
+ return unix.EOPNOTSUPP
+ }
+ if payloadLen > c.maxMessageSize {
+ log.Warningf("message %d has message size = %d which is larger than client.maxMessageSize = %d", m, payloadLen, c.maxMessageSize)
+ return unix.EIO
+ }
+ wantFDs := len(respFDs)
+ if wantFDs > math.MaxUint8 {
+ log.Warningf("want too many FDs: %d", wantFDs)
+ return unix.EINVAL
+ }
+
+ // Acquire a communicator.
+ comm := c.acquireCommunicator()
+ defer c.releaseCommunicator(comm)
+
+ // Marshal the request into comm's payload buffer and make the RPC.
+ reqMarshal(comm.PayloadBuf(payloadLen))
+ respM, respPayloadLen, err := comm.SndRcvMessage(m, payloadLen, uint8(wantFDs))
+
+ // Handle FD donation.
+ rcvFDs := comm.ReleaseFDs()
+ if numRcvFDs := len(rcvFDs); numRcvFDs+wantFDs > 0 {
+ // rcvFDs is memory owned by comm and cannot be returned to the caller.
+ // Copy it into the caller's buffer.
+ numFDCopied := copy(respFDs, rcvFDs)
+ if numFDCopied < numRcvFDs {
+ log.Warningf("%d unexpected FDs were donated by the server, wanted", numRcvFDs-numFDCopied, wantFDs)
+ closeFDs(rcvFDs[numFDCopied:])
+ }
+ if numFDCopied < wantFDs {
+ for i := numFDCopied; i < wantFDs; i++ {
+ respFDs[i] = -1
+ }
+ }
+ }
+
+ // Error cases.
+ if err != nil {
+ closeFDs(respFDs)
+ return err
+ }
+ if respM == Error {
+ closeFDs(respFDs)
+ var resp ErrorResp
+ resp.UnmarshalUnsafe(comm.PayloadBuf(respPayloadLen))
+ return unix.Errno(resp.errno)
+ }
+ if respM != m {
+ closeFDs(respFDs)
+ log.Warningf("sent %d message but got %d in response", m, respM)
+ return unix.EINVAL
+ }
+
+ // Success. The payload must be unmarshalled *before* comm is released.
+ respUnmarshal(comm.PayloadBuf(respPayloadLen))
+ return nil
+}
+
+// Postcondition: releaseCommunicator() must be called on the returned value.
+func (c *Client) acquireCommunicator() Communicator {
+ // Prefer using channel over socket because:
+ // - Channel uses a shared memory region for passing messages. IO from shared
+ // memory is faster and does not involve making a syscall.
+ // - No intermediate buffer allocation needed. With a channel, the message
+ // can be directly pasted into the shared memory region.
+ if ch := c.getChannel(); ch != nil {
+ return ch
+ }
+
+ c.sockMu.Lock()
+ return c.sockComm
+}
+
+// Precondition: comm must have been acquired via acquireCommunicator().
+func (c *Client) releaseCommunicator(comm Communicator) {
+ switch t := comm.(type) {
+ case *sockCommunicator:
+ c.sockMu.Unlock() // +checklocksforce: locked in acquireCommunicator().
+ case *channel:
+ c.releaseChannel(t)
+ default:
+ panic(fmt.Sprintf("unknown communicator type %T", t))
+ }
+}
+
+// getChannel pops a channel from the available channels stack. The caller must
+// release the channel after use.
+func (c *Client) getChannel() *channel {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ if len(c.availableChannels) == 0 {
+ return nil
+ }
+
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ c.activeWg.Add(1)
+ return ch
+}
+
+// releaseChannel pushes the passed channel back onto the available-channels
+// stack, unless the channel is dead or the client is shutting down.
+func (c *Client) releaseChannel(ch *channel) {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+
+ // If availableChannels is nil, then watchdog has fired and the client is
+ // shutting down. So don't make this channel available again.
+ if !ch.dead && c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
+ }
+ c.activeWg.Done()
+}
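To make the allocation note on SndRcvMessage concrete, here is a hedged sketch of a call site in the style of the ClientFD helpers (client_file.go, next file); the bound methods are passed as plain function values:

    // statByID mirrors ClientFD.StatTo: passing req.MarshalUnsafe and
    // stat.UnmarshalUnsafe as func([]byte) values avoids converting either
    // value to the marshal.Marshallable interface, which would heap-allocate
    // at every call even though neither value escapes.
    func statByID(ctx context.Context, client *Client, fd FDID) (linux.Statx, error) {
        req := StatReq{FD: fd}
        var stat linux.Statx
        ctx.UninterruptibleSleepStart(false)
        err := client.SndRcvMessage(FStat, uint32(req.SizeBytes()), req.MarshalUnsafe, stat.UnmarshalUnsafe, nil)
        ctx.UninterruptibleSleepFinish(false)
        return stat, err
    }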
diff --git a/pkg/lisafs/client_file.go b/pkg/lisafs/client_file.go
new file mode 100644
index 000000000..170c15705
--- /dev/null
+++ b/pkg/lisafs/client_file.go
@@ -0,0 +1,528 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "fmt"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+// ClientFD is a wrapper around FDID that provides client-side utilities
+// so that RPC making is easier.
+type ClientFD struct {
+ fd FDID
+ client *Client
+}
+
+// ID returns the underlying FDID.
+func (f *ClientFD) ID() FDID {
+ return f.fd
+}
+
+// Client returns the backing Client.
+func (f *ClientFD) Client() *Client {
+ return f.client
+}
+
+// NewFD initializes a new ClientFD.
+func (c *Client) NewFD(fd FDID) ClientFD {
+ return ClientFD{
+ client: c,
+ fd: fd,
+ }
+}
+
+// Ok returns true if the underlying FD is ok.
+func (f *ClientFD) Ok() bool {
+ return f.fd.Ok()
+}
+
+// CloseBatched queues this FD to be closed on the server and resets f.fd.
+// This may invoke the Close RPC if the queue becomes full.
+func (f *ClientFD) CloseBatched(ctx context.Context) {
+ f.client.CloseFDBatched(ctx, f.fd)
+ f.fd = InvalidFDID
+}
+
+// Close closes this FD immediately (invoking a Close RPC). Consider using
+// CloseBatched if closing this FD on the remote right away is not critical.
+func (f *ClientFD) Close(ctx context.Context) error {
+ fdArr := [1]FDID{f.fd}
+ req := CloseReq{FDs: fdArr[:]}
+
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(Close, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// OpenAt makes the OpenAt RPC.
+func (f *ClientFD) OpenAt(ctx context.Context, flags uint32) (FDID, int, error) {
+ req := OpenAtReq{
+ FD: f.fd,
+ Flags: flags,
+ }
+ var respFD [1]int
+ var resp OpenAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(OpenAt, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalUnsafe, respFD[:])
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.NewFD, respFD[0], err
+}
+
+// OpenCreateAt makes the OpenCreateAt RPC.
+func (f *ClientFD) OpenCreateAt(ctx context.Context, name string, flags uint32, mode linux.FileMode, uid UID, gid GID) (Inode, FDID, int, error) {
+ var req OpenCreateAtReq
+ req.DirFD = f.fd
+ req.Name = SizedString(name)
+ req.Flags = primitive.Uint32(flags)
+ req.Mode = mode
+ req.UID = uid
+ req.GID = gid
+
+ var respFD [1]int
+ var resp OpenCreateAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(OpenCreateAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, respFD[:])
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.Child, resp.NewFD, respFD[0], err
+}
+
+// StatTo makes the Fstat RPC and populates stat with the result.
+func (f *ClientFD) StatTo(ctx context.Context, stat *linux.Statx) error {
+ req := StatReq{FD: f.fd}
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FStat, uint32(req.SizeBytes()), req.MarshalUnsafe, stat.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// Sync makes the Fsync RPC.
+func (f *ClientFD) Sync(ctx context.Context) error {
+ req := FsyncReq{FDs: []FDID{f.fd}}
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FSync, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// chunkify applies fn to buf in chunks based on chunkSize.
+func chunkify(chunkSize uint64, buf []byte, fn func([]byte, uint64) (uint64, error)) (uint64, error) {
+ toProcess := uint64(len(buf))
+ var (
+ totalProcessed uint64
+ curProcessed uint64
+ off uint64
+ err error
+ )
+ for {
+ if totalProcessed == toProcess {
+ return totalProcessed, nil
+ }
+
+ if totalProcessed+chunkSize > toProcess {
+ curProcessed, err = fn(buf[totalProcessed:], off)
+ } else {
+ curProcessed, err = fn(buf[totalProcessed:totalProcessed+chunkSize], off)
+ }
+ totalProcessed += curProcessed
+ off += curProcessed
+
+ if err != nil {
+ return totalProcessed, err
+ }
+
+ // Return partial result immediately.
+ if curProcessed < chunkSize {
+ return totalProcessed, nil
+ }
+
+ // If we received more bytes than we ever requested, this is a problem.
+ if totalProcessed > toProcess {
+ panic(fmt.Sprintf("bytes completed (%d)) > requested (%d)", totalProcessed, toProcess))
+ }
+ }
+}
+
+// Read makes the PRead RPC.
+func (f *ClientFD) Read(ctx context.Context, dst []byte, offset uint64) (uint64, error) {
+ var resp PReadResp
+ // maxDataReadSize is the maximum amount of data we can read at once
+ // (maximum message size - metadata size present in resp). On the
+ // uninitialized resp, SizeBytes() correctly returns only the metadata size
+ // (since the read buffer is empty).
+ maxDataReadSize := uint64(f.client.maxMessageSize) - uint64(resp.SizeBytes())
+ return chunkify(maxDataReadSize, dst, func(buf []byte, curOff uint64) (uint64, error) {
+ req := PReadReq{
+ Offset: offset + curOff,
+ FD: f.fd,
+ Count: uint32(len(buf)),
+ }
+
+ // The response will be unmarshalled into resp. Set resp.Buf up front so
+ // that unmarshalling does not need to allocate a temporary buffer;
+ // PReadResp.UnmarshalBytes expects it to be set.
+ resp.Buf = buf
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return uint64(resp.NumBytes), err
+ })
+}
+
+// Write makes the PWrite RPC.
+func (f *ClientFD) Write(ctx context.Context, src []byte, offset uint64) (uint64, error) {
+ var req PWriteReq
+ // maxDataWriteSize is the maximum amount of data we can write at once
+ // (maximum message size - metadata size present in req). On the
+ // uninitialized req, SizeBytes() correctly returns only the metadata size
+ // (since the write buffer is empty).
+ maxDataWriteSize := uint64(f.client.maxMessageSize) - uint64(req.SizeBytes())
+ return chunkify(maxDataWriteSize, src, func(buf []byte, curOff uint64) (uint64, error) {
+ req = PWriteReq{
+ Offset: primitive.Uint64(offset + curOff),
+ FD: f.fd,
+ NumBytes: primitive.Uint32(len(buf)),
+ Buf: buf,
+ }
+
+ var resp PWriteResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.Count, err
+ })
+}
+
+// MkdirAt makes the MkdirAt RPC.
+func (f *ClientFD) MkdirAt(ctx context.Context, name string, mode linux.FileMode, uid UID, gid GID) (*Inode, error) {
+ var req MkdirAtReq
+ req.DirFD = f.fd
+ req.Name = SizedString(name)
+ req.Mode = mode
+ req.UID = uid
+ req.GID = gid
+
+ var resp MkdirAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(MkdirAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return &resp.ChildDir, err
+}
+
+// SymlinkAt makes the SymlinkAt RPC.
+func (f *ClientFD) SymlinkAt(ctx context.Context, name, target string, uid UID, gid GID) (*Inode, error) {
+ req := SymlinkAtReq{
+ DirFD: f.fd,
+ Name: SizedString(name),
+ Target: SizedString(target),
+ UID: uid,
+ GID: gid,
+ }
+
+ var resp SymlinkAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(SymlinkAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return &resp.Symlink, err
+}
+
+// LinkAt makes the LinkAt RPC.
+func (f *ClientFD) LinkAt(ctx context.Context, targetFD FDID, name string) (*Inode, error) {
+ req := LinkAtReq{
+ DirFD: f.fd,
+ Target: targetFD,
+ Name: SizedString(name),
+ }
+
+ var resp LinkAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(LinkAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return &resp.Link, err
+}
+
+// MknodAt makes the MknodAt RPC.
+func (f *ClientFD) MknodAt(ctx context.Context, name string, mode linux.FileMode, uid UID, gid GID, minor, major uint32) (*Inode, error) {
+ var req MknodAtReq
+ req.DirFD = f.fd
+ req.Name = SizedString(name)
+ req.Mode = mode
+ req.UID = uid
+ req.GID = gid
+ req.Minor = primitive.Uint32(minor)
+ req.Major = primitive.Uint32(major)
+
+ var resp MknodAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(MknodAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return &resp.Child, err
+}
+
+// SetStat makes the SetStat RPC.
+func (f *ClientFD) SetStat(ctx context.Context, stat *linux.Statx) (uint32, error, error) {
+ req := SetStatReq{
+ FD: f.fd,
+ Mask: stat.Mask,
+ Mode: uint32(stat.Mode),
+ UID: UID(stat.UID),
+ GID: GID(stat.GID),
+ Size: stat.Size,
+ Atime: linux.Timespec{
+ Sec: stat.Atime.Sec,
+ Nsec: int64(stat.Atime.Nsec),
+ },
+ Mtime: linux.Timespec{
+ Sec: stat.Mtime.Sec,
+ Nsec: int64(stat.Mtime.Nsec),
+ },
+ }
+
+ var resp SetStatResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(SetStat, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.FailureMask, unix.Errno(resp.FailureErrNo), err
+}
+
+// WalkMultiple makes the Walk RPC with multiple path components.
+func (f *ClientFD) WalkMultiple(ctx context.Context, names []string) (WalkStatus, []Inode, error) {
+ req := WalkReq{
+ DirFD: f.fd,
+ Path: StringArray(names),
+ }
+
+ var resp WalkResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(Walk, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.Status, resp.Inodes, err
+}
+
+// Walk makes the Walk RPC with just one path component to walk.
+func (f *ClientFD) Walk(ctx context.Context, name string) (*Inode, error) {
+ req := WalkReq{
+ DirFD: f.fd,
+ Path: []string{name},
+ }
+
+ var inode [1]Inode
+ resp := WalkResp{Inodes: inode[:]}
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(Walk, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ return nil, err
+ }
+
+ switch resp.Status {
+ case WalkComponentDoesNotExist:
+ return nil, unix.ENOENT
+ case WalkComponentSymlink:
+ // f is not a directory which can be walked on.
+ return nil, unix.ENOTDIR
+ }
+
+ if n := len(resp.Inodes); n > 1 {
+ for i := range resp.Inodes {
+ f.client.CloseFDBatched(ctx, resp.Inodes[i].ControlFD)
+ }
+ log.Warningf("requested to walk one component, but got %d results", n)
+ return nil, unix.EIO
+ } else if n == 0 {
+ log.Warningf("walk has success status but no results returned")
+ return nil, unix.ENOENT
+ }
+ return &inode[0], err
+}
+
+// WalkStat makes the WalkStat RPC with multiple path components to walk.
+func (f *ClientFD) WalkStat(ctx context.Context, names []string) ([]linux.Statx, error) {
+ req := WalkReq{
+ DirFD: f.fd,
+ Path: StringArray(names),
+ }
+
+ var resp WalkStatResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(WalkStat, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.Stats, err
+}
+
+// StatFSTo makes the FStatFS RPC and populates statFS with the result.
+func (f *ClientFD) StatFSTo(ctx context.Context, statFS *StatFS) error {
+ req := FStatFSReq{FD: f.fd}
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FStatFS, uint32(req.SizeBytes()), req.MarshalUnsafe, statFS.UnmarshalUnsafe, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// Allocate makes the FAllocate RPC.
+func (f *ClientFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ req := FAllocateReq{
+ FD: f.fd,
+ Mode: mode,
+ Offset: offset,
+ Length: length,
+ }
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FAllocate, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// ReadLinkAt makes the ReadLinkAt RPC.
+func (f *ClientFD) ReadLinkAt(ctx context.Context) (string, error) {
+ req := ReadLinkAtReq{FD: f.fd}
+ var resp ReadLinkAtResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(ReadLinkAt, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return string(resp.Target), err
+}
+
+// Flush makes the Flush RPC.
+func (f *ClientFD) Flush(ctx context.Context) error {
+ if !f.client.IsSupported(Flush) {
+ // If Flush is not supported, it probably means that it would be a noop.
+ return nil
+ }
+ req := FlushReq{FD: f.fd}
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(Flush, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// Connect makes the Connect RPC.
+func (f *ClientFD) Connect(ctx context.Context, sockType linux.SockType) (int, error) {
+ req := ConnectReq{FD: f.fd, SockType: uint32(sockType)}
+ var sockFD [1]int
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(Connect, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, sockFD[:])
+ ctx.UninterruptibleSleepFinish(false)
+ if err == nil && sockFD[0] < 0 {
+ err = unix.EBADF
+ }
+ return sockFD[0], err
+}
+
+// UnlinkAt makes the UnlinkAt RPC.
+func (f *ClientFD) UnlinkAt(ctx context.Context, name string, flags uint32) error {
+ req := UnlinkAtReq{
+ DirFD: f.fd,
+ Name: SizedString(name),
+ Flags: primitive.Uint32(flags),
+ }
+
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(UnlinkAt, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// RenameTo makes the RenameAt RPC which renames f to newDirFD directory with
+// name newName.
+func (f *ClientFD) RenameTo(ctx context.Context, newDirFD FDID, newName string) error {
+ req := RenameAtReq{
+ Renamed: f.fd,
+ NewDir: newDirFD,
+ NewName: SizedString(newName),
+ }
+
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(RenameAt, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// Getdents64 makes the Getdents64 RPC.
+func (f *ClientFD) Getdents64(ctx context.Context, count int32) ([]Dirent64, error) {
+ req := Getdents64Req{
+ DirFD: f.fd,
+ Count: count,
+ }
+
+ var resp Getdents64Resp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(Getdents64, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.Dirents, err
+}
+
+// ListXattr makes the FListXattr RPC.
+func (f *ClientFD) ListXattr(ctx context.Context, size uint64) ([]string, error) {
+ req := FListXattrReq{
+ FD: f.fd,
+ Size: size,
+ }
+
+ var resp FListXattrResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FListXattr, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return resp.Xattrs, err
+}
+
+// GetXattr makes the FGetXattr RPC.
+func (f *ClientFD) GetXattr(ctx context.Context, name string, size uint64) (string, error) {
+ req := FGetXattrReq{
+ FD: f.fd,
+ Name: SizedString(name),
+ BufSize: primitive.Uint32(size),
+ }
+
+ var resp FGetXattrResp
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FGetXattr, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return string(resp.Value), err
+}
+
+// SetXattr makes the FSetXattr RPC.
+func (f *ClientFD) SetXattr(ctx context.Context, name string, value string, flags uint32) error {
+ req := FSetXattrReq{
+ FD: f.fd,
+ Name: SizedString(name),
+ Value: SizedString(value),
+ Flags: primitive.Uint32(flags),
+ }
+
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FSetXattr, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+// RemoveXattr makes the FRemoveXattr RPC.
+func (f *ClientFD) RemoveXattr(ctx context.Context, name string) error {
+ req := FRemoveXattrReq{
+ FD: f.fd,
+ Name: SizedString(name),
+ }
+
+ ctx.UninterruptibleSleepStart(false)
+ err := f.client.SndRcvMessage(FRemoveXattr, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
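A hedged sketch tying the helpers above together: walk to a child, open it, and read from it over RPC. The file name is hypothetical and error paths are kept minimal.

    // readChild walks from a directory FD to a child, opens it read-only,
    // and reads the first 4 KiB over the PRead RPC. hostFD is a donated host
    // FD (or -1 if none was donated); this sketch does not use it.
    func readChild(ctx context.Context, client *Client, dir ClientFD) ([]byte, error) {
        inode, err := dir.Walk(ctx, "data.txt")
        if err != nil {
            return nil, err
        }
        child := client.NewFD(inode.ControlFD)
        defer child.CloseBatched(ctx)

        openFDID, hostFD, err := child.OpenAt(ctx, unix.O_RDONLY)
        if err != nil {
            return nil, err
        }
        if hostFD >= 0 {
            unix.Close(hostFD) // Read over RPC instead of the donated FD.
        }
        openFD := client.NewFD(openFDID)
        defer openFD.CloseBatched(ctx)

        buf := make([]byte, 4096)
        n, err := openFD.Read(ctx, buf, 0)
        return buf[:n], err
    }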
diff --git a/pkg/lisafs/communicator.go b/pkg/lisafs/communicator.go
new file mode 100644
index 000000000..ec2035158
--- /dev/null
+++ b/pkg/lisafs/communicator.go
@@ -0,0 +1,80 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import "golang.org/x/sys/unix"
+
+// Communicator abstracts the transport over which the client and the server
+// exchange messages; it captures exactly how one endpoint communicates with
+// its peer.
+type Communicator interface {
+ // PayloadBuf returns a slice to the payload section of its internal buffer
+ // where the message can be marshalled. The handlers should use this to
+ // populate the payload buffer with the message.
+ //
+ // The payload buffer contents *should* be preserved across calls with
+ // different sizes. Note that this is not a guarantee, because a compromised
+ // owner of a "shared" payload buffer can tamper with its contents anytime,
+ // even when it's not its turn to do so.
+ PayloadBuf(size uint32) []byte
+
+ // SndRcvMessage sends message m. The caller must have populated PayloadBuf()
+ // with payloadLen bytes. The caller expects to receive wantFDs FDs.
+ // Any received FDs must be accessible via ReleaseFDs(). It returns the
+ // response message along with the response payload length.
+ SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error)
+
+ // DonateFD makes fd non-blocking and starts tracking it. The next call to
+ // ReleaseFDs will include fd in the order it was added. Communicator takes
+ // ownership of fd. Server side should call this.
+ DonateFD(fd int) error
+
+ // TrackFD starts tracking fd. The next call to ReleaseFDs will include fd
+ // in the order it was added. Communicator takes ownership of fd. The
+ // client side should use this for accumulating received FDs.
+ TrackFD(fd int)
+
+ // ReleaseFDs returns the accumulated FDs and stops tracking them. The
+ // ownership of the FDs is transferred to the caller.
+ ReleaseFDs() []int
+}
+
+// fdTracker is a partial implementation of Communicator. It can be embedded in
+// Communicator implementations to keep track of FD donations.
+type fdTracker struct {
+ fds []int
+}
+
+// DonateFD implements Communicator.DonateFD.
+func (d *fdTracker) DonateFD(fd int) error {
+ // Make sure the FD is non-blocking.
+ if err := unix.SetNonblock(fd, true); err != nil {
+ unix.Close(fd)
+ return err
+ }
+ d.TrackFD(fd)
+ return nil
+}
+
+// TrackFD implements Communicator.TrackFD.
+func (d *fdTracker) TrackFD(fd int) {
+ d.fds = append(d.fds, fd)
+}
+
+// ReleaseFDs implements Communicator.ReleaseFDs.
+func (d *fdTracker) ReleaseFDs() []int {
+ ret := d.fds
+ d.fds = d.fds[:0]
+ return ret
+}
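A hedged sketch of how a hypothetical new transport would embed fdTracker, implementing only the buffer and wire methods itself; the loopback behavior stands in for a real peer.

    // loopbackCommunicator is a hypothetical transport that reflects requests
    // back to the sender; it gets DonateFD, TrackFD, and ReleaseFDs for free
    // from the embedded fdTracker.
    type loopbackCommunicator struct {
        fdTracker
        buf []byte
    }

    var _ Communicator = (*loopbackCommunicator)(nil)

    // PayloadBuf grows the internal buffer on demand and hands out a slice
    // of exactly the requested size, preserving prior contents.
    func (l *loopbackCommunicator) PayloadBuf(size uint32) []byte {
        if uint32(len(l.buf)) < size {
            newBuf := make([]byte, size)
            copy(newBuf, l.buf)
            l.buf = newBuf
        }
        return l.buf[:size]
    }

    // SndRcvMessage would ship the header and payload to a peer and block
    // for a response on a real transport; this stub echoes the request.
    func (l *loopbackCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) {
        return m, payloadLen, nil
    }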
diff --git a/pkg/lisafs/connection.go b/pkg/lisafs/connection.go
new file mode 100644
index 000000000..f6e5ecb4f
--- /dev/null
+++ b/pkg/lisafs/connection.go
@@ -0,0 +1,320 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Connection represents a connection between a mount point in the client and a
+// mount point in the server. It is owned by the server on which it was started
+// and facilitates communication with the client mount.
+//
+// Each connection is set up using a unix domain socket. One end is owned by
+// the server and the other end is owned by the client. The connection may
+// spawn additional communication channels for the same mount to increase
+// RPC concurrency.
+type Connection struct {
+ // server is the server on which this connection was created. It is immutably
+ // associated with it for its entire lifetime.
+ server *Server
+
+ // mounted is a one way flag indicating whether this connection has been
+ // mounted correctly and the server is initialized properly.
+ mounted bool
+
+ // readonly indicates if this connection is readonly. All write operations
+ // will fail with EROFS.
+ readonly bool
+
+ // sockComm is the main socket by which this connection is established.
+ sockComm *sockCommunicator
+
+ // channelsMu protects channels.
+ channelsMu sync.Mutex
+ // channels keeps track of all open channels.
+ channels []*channel
+
+ // activeWg represents active channels.
+ activeWg sync.WaitGroup
+
+ // reqGate counts requests that are still being handled.
+ reqGate sync.Gate
+
+ // channelAlloc is used to allocate memory for channels.
+ channelAlloc *flipcall.PacketWindowAllocator
+
+ fdsMu sync.RWMutex
+ // fds keeps track of the FDs opened on this connection. It is protected by fdsMu.
+ fds map[FDID]genericFD
+ // nextFDID is the next available FDID. It is protected by fdsMu.
+ nextFDID FDID
+}
+
+// CreateConnection initializes a new connection on server s. The connection
+// must be started separately.
+func (s *Server) CreateConnection(sock *unet.Socket, readonly bool) (*Connection, error) {
+ c := &Connection{
+ sockComm: newSockComm(sock),
+ server: s,
+ readonly: readonly,
+ channels: make([]*channel, 0, maxChannels()),
+ fds: make(map[FDID]genericFD),
+ nextFDID: InvalidFDID + 1,
+ }
+
+ alloc, err := flipcall.NewPacketWindowAllocator()
+ if err != nil {
+ return nil, err
+ }
+ c.channelAlloc = alloc
+ return c, nil
+}
+
+// Server returns the associated server.
+func (c *Connection) Server() *Server {
+ return c.server
+}
+
+// ServerImpl returns the associated server implementation.
+func (c *Connection) ServerImpl() ServerImpl {
+ return c.server.impl
+}
+
+// Run defines the lifecycle of a connection.
+func (c *Connection) Run() {
+ defer c.close()
+
+ // Start handling requests on this connection.
+ for {
+ m, payloadLen, err := c.sockComm.rcvMsg(0 /* wantFDs */)
+ if err != nil {
+ log.Debugf("sock read failed, closing connection: %v", err)
+ return
+ }
+
+ respM, respPayloadLen, respFDs := c.handleMsg(c.sockComm, m, payloadLen)
+ err = c.sockComm.sndPrepopulatedMsg(respM, respPayloadLen, respFDs)
+ closeFDs(respFDs)
+ if err != nil {
+ log.Debugf("sock write failed, closing connection: %v", err)
+ return
+ }
+ }
+}
+
+// service starts servicing the passed channel until the channel is shut
+// down. This is a blocking method and hence must be called in a separate
+// goroutine.
+func (c *Connection) service(ch *channel) error {
+ rcvDataLen, err := ch.data.RecvFirst()
+ if err != nil {
+ return err
+ }
+ for rcvDataLen > 0 {
+ m, payloadLen, err := ch.rcvMsg(rcvDataLen)
+ if err != nil {
+ return err
+ }
+ respM, respPayloadLen, respFDs := c.handleMsg(ch, m, payloadLen)
+ numFDs := ch.sendFDs(respFDs)
+ closeFDs(respFDs)
+
+ ch.marshalHdr(respM, numFDs)
+ rcvDataLen, err = ch.data.SendRecv(respPayloadLen + chanHeaderLen)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (c *Connection) respondError(comm Communicator, err unix.Errno) (MID, uint32, []int) {
+ resp := &ErrorResp{errno: uint32(err)}
+ respLen := uint32(resp.SizeBytes())
+ resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+ return Error, respLen, nil
+}
+
+func (c *Connection) handleMsg(comm Communicator, m MID, payloadLen uint32) (MID, uint32, []int) {
+ if !c.reqGate.Enter() {
+ // c.close() has been called; the connection is shutting down.
+ return c.respondError(comm, unix.ECONNRESET)
+ }
+ defer c.reqGate.Leave()
+
+ if !c.mounted && m != Mount {
+ log.Warningf("connection must first be mounted")
+ return c.respondError(comm, unix.EINVAL)
+ }
+
+ // Check if the message is supported for forward compatibility.
+ if int(m) >= len(c.server.handlers) || c.server.handlers[m] == nil {
+ log.Warningf("received request which is not supported by the server, MID = %d", m)
+ return c.respondError(comm, unix.EOPNOTSUPP)
+ }
+
+ // Try handling the request.
+ respPayloadLen, err := c.server.handlers[m](c, comm, payloadLen)
+ fds := comm.ReleaseFDs()
+ if err != nil {
+ closeFDs(fds)
+ return c.respondError(comm, p9.ExtractErrno(err))
+ }
+
+ return m, respPayloadLen, fds
+}
+
+func (c *Connection) close() {
+ // Wait for completion of all inflight requests. This is mostly so that if
+ // a request is stuck, the sandbox supervisor has the opportunity to kill
+ // us with SIGABRT to get a stack dump of the offending handler.
+ c.reqGate.Close()
+
+ // Shutdown and clean up channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.shutdown()
+ }
+ c.activeWg.Wait()
+ for _, ch := range c.channels {
+ ch.destroy()
+ }
+ // This is to prevent additional channels from being created.
+ c.channels = nil
+ c.channelsMu.Unlock()
+
+ // Free the channel memory.
+ if c.channelAlloc != nil {
+ c.channelAlloc.Destroy()
+ }
+
+ // Ensure the connection is closed.
+ c.sockComm.destroy()
+
+ // Cleanup all FDs.
+ c.fdsMu.Lock()
+ for fdid := range c.fds {
+ fd := c.removeFDLocked(fdid)
+ fd.DecRef(nil) // Drop the ref held by c.
+ }
+ c.fdsMu.Unlock()
+}
+
+// The caller gains a ref on the FD on success.
+func (c *Connection) lookupFD(id FDID) (genericFD, error) {
+ c.fdsMu.RLock()
+ defer c.fdsMu.RUnlock()
+
+ fd, ok := c.fds[id]
+ if !ok {
+ return nil, unix.EBADF
+ }
+ fd.IncRef()
+ return fd, nil
+}
+
+// LookupControlFD retrieves the control FD identified by id on this
+// connection. On success, the caller gains a ref on the FD.
+func (c *Connection) LookupControlFD(id FDID) (*ControlFD, error) {
+ fd, err := c.lookupFD(id)
+ if err != nil {
+ return nil, err
+ }
+
+ cfd, ok := fd.(*ControlFD)
+ if !ok {
+ fd.DecRef(nil)
+ return nil, unix.EINVAL
+ }
+ return cfd, nil
+}
+
+// LookupOpenFD retrieves the open FD identified by id on this
+// connection. On success, the caller gains a ref on the FD.
+func (c *Connection) LookupOpenFD(id FDID) (*OpenFD, error) {
+ fd, err := c.lookupFD(id)
+ if err != nil {
+ return nil, err
+ }
+
+ ofd, ok := fd.(*OpenFD)
+ if !ok {
+ fd.DecRef(nil)
+ return nil, unix.EINVAL
+ }
+ return ofd, nil
+}
+
+// insertFD inserts the passed fd into the internal data structure to track FDs.
+// The caller must hold a ref on fd which is transferred to the connection.
+func (c *Connection) insertFD(fd genericFD) FDID {
+ c.fdsMu.Lock()
+ defer c.fdsMu.Unlock()
+
+ res := c.nextFDID
+ c.nextFDID++
+ if c.nextFDID < res {
+ panic("ran out of FDIDs")
+ }
+ c.fds[res] = fd
+ return res
+}
+
+// RemoveFD makes c stop tracking the passed FDID and drops its ref on it.
+func (c *Connection) RemoveFD(id FDID) {
+ c.fdsMu.Lock()
+ fd := c.removeFDLocked(id)
+ c.fdsMu.Unlock()
+ if fd != nil {
+ // Drop the ref held by c. This can take arbitrarily long. So do not hold
+ // c.fdsMu while calling it.
+ fd.DecRef(nil)
+ }
+}
+
+// RemoveControlFDLocked is the same as RemoveFD with added preconditions.
+//
+// Preconditions:
+// * server's rename mutex must at least be read locked.
+// * id must be pointing to a control FD.
+func (c *Connection) RemoveControlFDLocked(id FDID) {
+ c.fdsMu.Lock()
+ fd := c.removeFDLocked(id)
+ c.fdsMu.Unlock()
+ if fd != nil {
+ // Drop the ref held by c. This can take arbitrarily long. So do not hold
+ // c.fdsMu while calling it.
+ fd.(*ControlFD).DecRefLocked()
+ }
+}
+
+// removeFDLocked makes c stop tracking the passed FDID. Note that the caller
+// must drop its ref on the returned fd (preferably without holding c.fdsMu).
+//
+// Precondition: c.fdsMu is locked.
+func (c *Connection) removeFDLocked(id FDID) genericFD {
+ fd := c.fds[id]
+ if fd == nil {
+ log.Warningf("removeFDLocked called on non-existent FDID %d", id)
+ return nil
+ }
+ delete(c.fds, id)
+ return fd
+}
diff --git a/pkg/lisafs/connection_test.go b/pkg/lisafs/connection_test.go
new file mode 100644
index 000000000..28ba47112
--- /dev/null
+++ b/pkg/lisafs/connection_test.go
@@ -0,0 +1,194 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package connection_test
+
+import (
+ "reflect"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/lisafs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+const (
+ dynamicMsgID = lisafs.Channel + 1
+ versionMsgID = dynamicMsgID + 1
+)
+
+var handlers = [...]lisafs.RPCHandler{
+ lisafs.Error: lisafs.ErrorHandler,
+ lisafs.Mount: lisafs.MountHandler,
+ lisafs.Channel: lisafs.ChannelHandler,
+ dynamicMsgID: dynamicMsgHandler,
+ versionMsgID: versionHandler,
+}
+
+// testServer implements lisafs.ServerImpl.
+type testServer struct {
+ lisafs.Server
+}
+
+var _ lisafs.ServerImpl = (*testServer)(nil)
+
+type testControlFD struct {
+ lisafs.ControlFD
+ lisafs.ControlFDImpl
+}
+
+func (fd *testControlFD) FD() *lisafs.ControlFD {
+ return &fd.ControlFD
+}
+
+// Mount implements lisafs.ServerImpl.Mount.
+func (s *testServer) Mount(c *lisafs.Connection, mountPath string) (lisafs.ControlFDImpl, lisafs.Inode, error) {
+ return &testControlFD{}, lisafs.Inode{ControlFD: 1}, nil
+}
+
+// MaxMessageSize implements lisafs.ServerImpl.MaxMessageSize.
+func (s *testServer) MaxMessageSize() uint32 {
+ return lisafs.MaxMessageSize()
+}
+
+// SupportedMessages implements lisafs.ServerImpl.SupportedMessages.
+func (s *testServer) SupportedMessages() []lisafs.MID {
+ return []lisafs.MID{
+ lisafs.Mount,
+ lisafs.Channel,
+ dynamicMsgID,
+ versionMsgID,
+ }
+}
+
+func runServerClient(t testing.TB, clientFn func(c *lisafs.Client)) {
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+
+ ts := &testServer{}
+ ts.Server.InitTestOnly(ts, handlers[:])
+ conn, err := ts.CreateConnection(serverSocket, false /* readonly */)
+ if err != nil {
+ t.Fatalf("starting connection failed: %v", err)
+ return
+ }
+ ts.StartConnection(conn)
+
+ c, _, err := lisafs.NewClient(clientSocket, "/")
+ if err != nil {
+ t.Fatalf("client creation failed: %v", err)
+ }
+
+ clientFn(c)
+
+ c.Close() // This should trigger client and server shutdown.
+ ts.Wait()
+}
+
+// TestStartUp tests that the server and client can be started up correctly.
+func TestStartUp(t *testing.T) {
+ runServerClient(t, func(c *lisafs.Client) {
+ if c.IsSupported(lisafs.Error) {
+ t.Errorf("sending error messages should not be supported")
+ }
+ })
+}
+
+func TestUnsupportedMessage(t *testing.T) {
+ unsupportedM := lisafs.MID(len(handlers))
+ runServerClient(t, func(c *lisafs.Client) {
+ if err := c.SndRcvMessage(unsupportedM, 0, lisafs.NoopMarshal, lisafs.NoopUnmarshal, nil); err != unix.EOPNOTSUPP {
+ t.Errorf("expected EOPNOTSUPP but got err: %v", err)
+ }
+ })
+}
+
+func dynamicMsgHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) {
+ var req lisafs.MsgDynamic
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ // Just echo back the message.
+ respPayloadLen := uint32(req.SizeBytes())
+ req.MarshalBytes(comm.PayloadBuf(respPayloadLen))
+ return respPayloadLen, nil
+}
+
+// TestStress stress tests sending many messages from various goroutines.
+func TestStress(t *testing.T) {
+ runServerClient(t, func(c *lisafs.Client) {
+ concurrency := 8
+ numMsgPerGoroutine := 5000
+ var clientWg sync.WaitGroup
+ for i := 0; i < concurrency; i++ {
+ clientWg.Add(1)
+ go func() {
+ defer clientWg.Done()
+
+ for j := 0; j < numMsgPerGoroutine; j++ {
+ // Create a massive random message.
+ var req lisafs.MsgDynamic
+ req.Randomize(100)
+
+ var resp lisafs.MsgDynamic
+ if err := c.SndRcvMessage(dynamicMsgID, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil); err != nil {
+ t.Errorf("SndRcvMessage: received unexpected error %v", err)
+ return
+ }
+ if !reflect.DeepEqual(&req, &resp) {
+ t.Errorf("response should be the same as request: request = %+v, response = %+v", req, resp)
+ }
+ }
+ }()
+ }
+
+ clientWg.Wait()
+ })
+}
+
+func versionHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) {
+ // Handlers usually create their own objects and return pointers to those.
+ // It might be tempting to reuse variables from an enclosing scope, but don't.
+ var rv lisafs.P9Version
+ rv.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ // Create a new response.
+ sv := lisafs.P9Version{
+ MSize: rv.MSize,
+ Version: "9P2000.L.Google.11",
+ }
+ respPayloadLen := uint32(sv.SizeBytes())
+ sv.MarshalBytes(comm.PayloadBuf(respPayloadLen))
+ return respPayloadLen, nil
+}
+
+// BenchmarkSendRecv exists to compete against p9's BenchmarkSendRecvChannel.
+func BenchmarkSendRecv(b *testing.B) {
+ b.ReportAllocs()
+ sendV := lisafs.P9Version{
+ MSize: 1 << 20,
+ Version: "9P2000.L.Google.12",
+ }
+
+ var recvV lisafs.P9Version
+ runServerClient(b, func(c *lisafs.Client) {
+ for i := 0; i < b.N; i++ {
+ if err := c.SndRcvMessage(versionMsgID, uint32(sendV.SizeBytes()), sendV.MarshalBytes, recvV.UnmarshalBytes, nil); err != nil {
+ b.Fatalf("unexpected error occurred: %v", err)
+ }
+ }
+ })
+}
diff --git a/pkg/lisafs/fd.go b/pkg/lisafs/fd.go
new file mode 100644
index 000000000..cc6919a1b
--- /dev/null
+++ b/pkg/lisafs/fd.go
@@ -0,0 +1,374 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// FDID (file descriptor identifier) is used to identify FDs on a connection.
+// Each connection has its own FDID namespace.
+//
+// +marshal slice:FDIDSlice
+type FDID uint32
+
+// InvalidFDID represents an invalid FDID.
+const InvalidFDID FDID = 0
+
+// Ok returns true if f is a valid FDID.
+func (f FDID) Ok() bool {
+ return f != InvalidFDID
+}
+
+// genericFD can represent a ControlFD or OpenFD.
+type genericFD interface {
+ refsvfs2.RefCounter
+}
+
+// A ControlFD is the gateway to the backing filesystem tree node. It is an
+// unusual concept. This exists to provide a safe way to do path-based
+// operations on the file. It performs operations that can modify the
+// filesystem tree and synchronizes these operations. See ControlFDImpl for
+// supported operations.
+//
+// It is not an inode, because multiple control FDs are allowed to exist on the
+// same file. It is not a file descriptor because it is not tied to any access
+// mode, i.e. a control FD can change its access mode based on the operation
+// being performed.
+//
+// Reference Model:
+// * When a control FD is created, the connection takes a ref on it which
+// represents the client's ref on the FD.
+// * The client can drop its ref via the Close RPC which will in turn make the
+// connection drop its ref.
+// * Each control FD holds a ref on its parent for its entire life time.
+type ControlFD struct {
+ controlFDRefs
+ controlFDEntry
+
+ // parent is the parent directory FD containing the file this FD represents.
+ // A ControlFD holds a ref on parent for its entire lifetime. If this FD
+ // represents the root, then parent is nil. parent may be a control FD from
+ // another connection (another mount point). parent is protected by the
+ // backing server's rename mutex.
+ parent *ControlFD
+
+ // name is the file path's last component name. If this FD represents the
+ // root directory, then name is the mount path. name is protected by the
+ // backing server's rename mutex.
+ name string
+
+ // children is a linked list of all children control FDs. As per reference
+ // model, all children hold a ref on this FD.
+ // children is protected by childrenMu and the server's rename mutex. For
+ // mutual exclusion, it is sufficient to either:
+ // * Hold the rename mutex for reading and lock childrenMu, or
+ // * Hold the rename mutex for writing.
+ childrenMu sync.Mutex
+ children controlFDList
+
+ // openFDs is a linked list of all FDs opened on this FD. As per reference
+ // model, all open FDs hold a ref on this FD.
+ openFDsMu sync.RWMutex
+ openFDs openFDList
+
+ // All the following fields are immutable.
+
+ // id is the unique FD identifier which identifies this FD on its connection.
+ id FDID
+
+ // conn is the backing connection owning this FD.
+ conn *Connection
+
+ // ftype is the file type of the backing inode. ftype.FileType() == ftype.
+ ftype linux.FileMode
+
+ // impl is the control FD implementation which embeds this struct. It
+ // contains all the implementation specific details.
+ impl ControlFDImpl
+}
+
+var _ genericFD = (*ControlFD)(nil)
+
+// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context
+// parameter should never be used. It exists solely to comply with the
+// refsvfs2.RefCounter interface.
+func (fd *ControlFD) DecRef(context.Context) {
+ fd.controlFDRefs.DecRef(func() {
+ if fd.parent != nil {
+ fd.conn.server.RenameMu.RLock()
+ fd.parent.childrenMu.Lock()
+ fd.parent.children.Remove(fd)
+ fd.parent.childrenMu.Unlock()
+ fd.conn.server.RenameMu.RUnlock()
+ fd.parent.DecRef(nil) // Drop the ref on the parent.
+ }
+ fd.impl.Close(fd.conn)
+ })
+}
+
+// DecRefLocked is the same as DecRef except for the added precondition.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) DecRefLocked() {
+ fd.controlFDRefs.DecRef(func() {
+ fd.clearParentLocked()
+ fd.impl.Close(fd.conn)
+ })
+}
+
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) clearParentLocked() {
+ if fd.parent == nil {
+ return
+ }
+ fd.parent.childrenMu.Lock()
+ fd.parent.children.Remove(fd)
+ fd.parent.childrenMu.Unlock()
+ fd.parent.DecRefLocked() // Drop the ref on the parent.
+}
+
+// Init must be called before first use of fd. It inserts fd into the
+// filesystem tree.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) Init(c *Connection, parent *ControlFD, name string, mode linux.FileMode, impl ControlFDImpl) {
+ // Initialize fd with 1 ref which is transferred to c via c.insertFD().
+ fd.controlFDRefs.InitRefs()
+ fd.conn = c
+ fd.id = c.insertFD(fd)
+ fd.name = name
+ fd.ftype = mode.FileType()
+ fd.impl = impl
+ fd.setParentLocked(parent)
+}
+
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) setParentLocked(parent *ControlFD) {
+ fd.parent = parent
+ if parent != nil {
+ parent.IncRef() // Hold a ref on parent.
+ parent.childrenMu.Lock()
+ parent.children.PushBack(fd)
+ parent.childrenMu.Unlock()
+ }
+}
+
+// FileType returns the file mode only containing the file type bits.
+func (fd *ControlFD) FileType() linux.FileMode {
+ return fd.ftype
+}
+
+// IsDir indicates whether fd represents a directory.
+func (fd *ControlFD) IsDir() bool {
+ return fd.ftype == unix.S_IFDIR
+}
+
+// IsRegular indicates whether fd represents a regular file.
+func (fd *ControlFD) IsRegular() bool {
+ return fd.ftype == unix.S_IFREG
+}
+
+// IsSymlink indicates whether fd represents a symbolic link.
+func (fd *ControlFD) IsSymlink() bool {
+ return fd.ftype == unix.S_IFLNK
+}
+
+// IsSocket indicates whether fd represents a socket.
+func (fd *ControlFD) IsSocket() bool {
+ return fd.ftype == unix.S_IFSOCK
+}
+
+// NameLocked returns the backing file's last component name.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) NameLocked() string {
+ return fd.name
+}
+
+// ParentLocked returns the parent control FD. Note that parent might be a
+// control FD from another connection on this server, so its ID must not be
+// returned on this connection because FDIDs are local to their connection.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) ParentLocked() ControlFDImpl {
+ if fd.parent == nil {
+ return nil
+ }
+ return fd.parent.impl
+}
+
+// ID returns fd's ID.
+func (fd *ControlFD) ID() FDID {
+ return fd.id
+}
+
+// FilePath returns the absolute path of the file fd was opened on. This is
+// expensive and must not be called on hot paths. FilePath acquires the rename
+// mutex for reading so callers should not be holding it.
+func (fd *ControlFD) FilePath() string {
+ // Lock the rename mutex for reading to ensure that the filesystem tree is not
+ // changed while we traverse it upwards.
+ fd.conn.server.RenameMu.RLock()
+ defer fd.conn.server.RenameMu.RUnlock()
+ return fd.FilePathLocked()
+}
+
+// FilePathLocked is the same as FilePath with the additional precondition.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) FilePathLocked() string {
+ // Walk upwards and prepend name to res.
+ var res fspath.Builder
+ for fd != nil {
+ res.PrependComponent(fd.name)
+ fd = fd.parent
+ }
+ return res.String()
+}
+
+// ForEachOpenFD executes fn on each FD opened on fd.
+func (fd *ControlFD) ForEachOpenFD(fn func(ofd OpenFDImpl)) {
+ fd.openFDsMu.RLock()
+ defer fd.openFDsMu.RUnlock()
+ for ofd := fd.openFDs.Front(); ofd != nil; ofd = ofd.Next() {
+ fn(ofd.impl)
+ }
+}
+
+// OpenFD represents an open file descriptor on the protocol. It corresponds
+// closely to a Linux file descriptor. Its operations are limited to the
+// file. Its operations are not allowed to modify or traverse the filesystem
+// tree. See OpenFDImpl for the supported operations.
+//
+// Reference Model:
+// * An OpenFD takes a reference on the control FD it was opened on.
+type OpenFD struct {
+ openFDRefs
+ openFDEntry
+
+ // All the following fields are immutable.
+
+ // controlFD is the ControlFD on which this FD was opened. OpenFD holds a ref
+ // on controlFD for its entire lifetime.
+ controlFD *ControlFD
+
+ // id is the unique FD identifier which identifies this FD on its connection.
+ id FDID
+
+ // Access mode for this FD.
+ readable bool
+ writable bool
+
+ // impl is the open FD implementation which embeds this struct. It
+ // contains all the implementation specific details.
+ impl OpenFDImpl
+}
+
+var _ genericFD = (*OpenFD)(nil)
+
+// ID returns fd's ID.
+func (fd *OpenFD) ID() FDID {
+ return fd.id
+}
+
+// ControlFD returns the control FD on which this FD was opened.
+func (fd *OpenFD) ControlFD() ControlFDImpl {
+ return fd.controlFD.impl
+}
+
+// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context
+// parameter should never be used. It exists solely to comply with the
+// refsvfs2.RefCounter interface.
+func (fd *OpenFD) DecRef(context.Context) {
+ fd.openFDRefs.DecRef(func() {
+ fd.controlFD.openFDsMu.Lock()
+ fd.controlFD.openFDs.Remove(fd)
+ fd.controlFD.openFDsMu.Unlock()
+ fd.controlFD.DecRef(nil) // Drop the ref on the control FD.
+ fd.impl.Close(fd.controlFD.conn)
+ })
+}
+
+// Init must be called before first use of fd.
+func (fd *OpenFD) Init(cfd *ControlFD, flags uint32, impl OpenFDImpl) {
+ // Initialize fd with 1 ref which is transferred to c via c.insertFD().
+ fd.openFDRefs.InitRefs()
+ fd.controlFD = cfd
+ fd.id = cfd.conn.insertFD(fd)
+ accessMode := flags & unix.O_ACCMODE
+ fd.readable = accessMode == unix.O_RDONLY || accessMode == unix.O_RDWR
+ fd.writable = accessMode == unix.O_WRONLY || accessMode == unix.O_RDWR
+ fd.impl = impl
+ cfd.IncRef() // Holds a ref on cfd for its lifetime.
+ cfd.openFDsMu.Lock()
+ cfd.openFDs.PushBack(fd)
+ cfd.openFDsMu.Unlock()
+}
+
+// ControlFDImpl contains implementation details for a ControlFD.
+// Implementations of ControlFDImpl should contain their associated ControlFD
+// by value as their first field.
+//
+// The operations that perform path traversal or any modification to the
+// filesystem tree must synchronize those modifications with the server's
+// rename mutex.
+type ControlFDImpl interface {
+ FD() *ControlFD
+ Close(c *Connection)
+ Stat(c *Connection, comm Communicator) (uint32, error)
+ SetStat(c *Connection, comm Communicator, stat SetStatReq) (uint32, error)
+ Walk(c *Connection, comm Communicator, path StringArray) (uint32, error)
+ WalkStat(c *Connection, comm Communicator, path StringArray) (uint32, error)
+ Open(c *Connection, comm Communicator, flags uint32) (uint32, error)
+ OpenCreate(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string, flags uint32) (uint32, error)
+ Mkdir(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string) (uint32, error)
+ Mknod(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string, minor uint32, major uint32) (uint32, error)
+ Symlink(c *Connection, comm Communicator, name string, target string, uid UID, gid GID) (uint32, error)
+ Link(c *Connection, comm Communicator, dir ControlFDImpl, name string) (uint32, error)
+ StatFS(c *Connection, comm Communicator) (uint32, error)
+ Readlink(c *Connection, comm Communicator) (uint32, error)
+ Connect(c *Connection, comm Communicator, sockType uint32) error
+ Unlink(c *Connection, name string, flags uint32) error
+ RenameLocked(c *Connection, newDir ControlFDImpl, newName string) (func(ControlFDImpl), func(), error)
+ GetXattr(c *Connection, comm Communicator, name string, size uint32) (uint32, error)
+ SetXattr(c *Connection, name string, value string, flags uint32) error
+ ListXattr(c *Connection, comm Communicator, size uint64) (uint32, error)
+ RemoveXattr(c *Connection, comm Communicator, name string) error
+}
+
+// OpenFDImpl contains implementation details for an OpenFD. Implementations of
+// OpenFDImpl should contain their associated OpenFD by value as their first
+// field.
+//
+// Since these operations do not perform any path traversal or any modification
+// to the filesystem tree, there is no need to synchronize with rename
+// operations.
+type OpenFDImpl interface {
+ FD() *OpenFD
+ Close(c *Connection)
+ Stat(c *Connection, comm Communicator) (uint32, error)
+ Sync(c *Connection) error
+ Write(c *Connection, comm Communicator, buf []byte, off uint64) (uint32, error)
+ Read(c *Connection, comm Communicator, off uint64, count uint32) (uint32, error)
+ Allocate(c *Connection, mode, off, length uint64) error
+ Flush(c *Connection) error
+ Getdent64(c *Connection, comm Communicator, count uint32, seek0 bool) (uint32, error)
+}
diff --git a/pkg/lisafs/handlers.go b/pkg/lisafs/handlers.go
new file mode 100644
index 000000000..82807734d
--- /dev/null
+++ b/pkg/lisafs/handlers.go
@@ -0,0 +1,768 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "fmt"
+ "path"
+ "path/filepath"
+ "strings"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+const (
+ allowedOpenFlags = unix.O_ACCMODE | unix.O_TRUNC
+ setStatSupportedMask = unix.STATX_MODE | unix.STATX_UID | unix.STATX_GID | unix.STATX_SIZE | unix.STATX_ATIME | unix.STATX_MTIME
+)
+
+// RPCHandler defines a handler that is invoked when the associated message is
+// received. The handler is responsible for:
+//
+// * Unmarshalling the request from the passed payload and interpreting it.
+// * Marshalling the response into the communicator's payload buffer.
+// * Returning the number of payload bytes written.
+// * Donating any FDs (if needed) to comm, which will in turn donate them to
+//   the client.
+type RPCHandler func(c *Connection, comm Communicator, payloadLen uint32) (uint32, error)
+
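+// As an illustration, a minimal echo handler following this contract could
+// look like the sketch below (hypothetical; it mirrors dynamicMsgHandler in
+// connection_test.go and uses the sample MsgDynamic message type):
+//
+//	func echoHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+//		var req MsgDynamic
+//		req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+//
+//		// Echo the request back as the response payload.
+//		respPayloadLen := uint32(req.SizeBytes())
+//		req.MarshalBytes(comm.PayloadBuf(respPayloadLen))
+//		return respPayloadLen, nil
+//	}
+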
+var handlers = [...]RPCHandler{
+ Error: ErrorHandler,
+ Mount: MountHandler,
+ Channel: ChannelHandler,
+ FStat: FStatHandler,
+ SetStat: SetStatHandler,
+ Walk: WalkHandler,
+ WalkStat: WalkStatHandler,
+ OpenAt: OpenAtHandler,
+ OpenCreateAt: OpenCreateAtHandler,
+ Close: CloseHandler,
+ FSync: FSyncHandler,
+ PWrite: PWriteHandler,
+ PRead: PReadHandler,
+ MkdirAt: MkdirAtHandler,
+ MknodAt: MknodAtHandler,
+ SymlinkAt: SymlinkAtHandler,
+ LinkAt: LinkAtHandler,
+ FStatFS: FStatFSHandler,
+ FAllocate: FAllocateHandler,
+ ReadLinkAt: ReadLinkAtHandler,
+ Flush: FlushHandler,
+ Connect: ConnectHandler,
+ UnlinkAt: UnlinkAtHandler,
+ RenameAt: RenameAtHandler,
+ Getdents64: Getdents64Handler,
+ FGetXattr: FGetXattrHandler,
+ FSetXattr: FSetXattrHandler,
+ FListXattr: FListXattrHandler,
+ FRemoveXattr: FRemoveXattrHandler,
+}
+
+// ErrorHandler handles Error message.
+func ErrorHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ // Client should never send Error.
+ return 0, unix.EINVAL
+}
+
+// MountHandler handles the Mount RPC. Note that there cannot be concurrent
+// executions of MountHandler on a connection because the connection enforces
+// that Mount is the first message on the connection. Only after the connection
+// has been successfully mounted can other channels be created.
+func MountHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req MountReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ mountPath := path.Clean(string(req.MountPath))
+ if !filepath.IsAbs(mountPath) {
+ log.Warningf("mountPath %q is not absolute", mountPath)
+ return 0, unix.EINVAL
+ }
+
+ if c.mounted {
+ log.Warningf("connection has already been mounted at %q", mountPath)
+ return 0, unix.EBUSY
+ }
+
+ rootFD, rootIno, err := c.ServerImpl().Mount(c, mountPath)
+ if err != nil {
+ return 0, err
+ }
+
+ c.server.addMountPoint(rootFD.FD())
+ c.mounted = true
+ resp := MountResp{
+ Root: rootIno,
+ SupportedMs: c.ServerImpl().SupportedMessages(),
+ MaxMessageSize: primitive.Uint32(c.ServerImpl().MaxMessageSize()),
+ }
+ respPayloadLen := uint32(resp.SizeBytes())
+ resp.MarshalBytes(comm.PayloadBuf(respPayloadLen))
+ return respPayloadLen, nil
+}
+
+// ChannelHandler handles the Channel RPC.
+func ChannelHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ ch, desc, fdSock, err := c.createChannel(c.ServerImpl().MaxMessageSize())
+ if err != nil {
+ return 0, err
+ }
+
+ // Start servicing the channel in a separate goroutine.
+ c.activeWg.Add(1)
+ go func() {
+ if err := c.service(ch); err != nil {
+ // Don't log the shutdown error, which is expected during server shutdown.
+ if _, ok := err.(flipcall.ShutdownError); !ok {
+ log.Warningf("lisafs.Connection.service(channel = @%p): %v", ch, err)
+ }
+ }
+ c.activeWg.Done()
+ }()
+
+ clientDataFD, err := unix.Dup(desc.FD)
+ if err != nil {
+ unix.Close(fdSock)
+ ch.shutdown()
+ return 0, err
+ }
+
+ // Respond to client with successful channel creation message.
+ if err := comm.DonateFD(clientDataFD); err != nil {
+ return 0, err
+ }
+ if err := comm.DonateFD(fdSock); err != nil {
+ return 0, err
+ }
+ resp := ChannelResp{
+ dataOffset: desc.Offset,
+ dataLength: uint64(desc.Length),
+ }
+ respLen := uint32(resp.SizeBytes())
+ resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+ return respLen, nil
+}
+
+// FStatHandler handles the FStat RPC.
+func FStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req StatReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.lookupFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+
+ switch t := fd.(type) {
+ case *ControlFD:
+ return t.impl.Stat(c, comm)
+ case *OpenFD:
+ return t.impl.Stat(c, comm)
+ default:
+ panic(fmt.Sprintf("unknown fd type %T", t))
+ }
+}
+
+// SetStatHandler handles the SetStat RPC.
+func SetStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+
+ var req SetStatReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+
+ if req.Mask&^setStatSupportedMask != 0 {
+ return 0, unix.EPERM
+ }
+
+ return fd.impl.SetStat(c, comm, req)
+}
+
+// WalkHandler handles the Walk RPC.
+func WalkHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req WalkReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+ for _, name := range req.Path {
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+ }
+
+ return fd.impl.Walk(c, comm, req.Path)
+}
+
+// WalkStatHandler handles the WalkStat RPC.
+func WalkStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req WalkReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+
+ // Note that this fd is allowed to not actually be a directory when the
+ // only path component to walk is "" (self).
+ if !fd.IsDir() {
+ if len(req.Path) > 1 || (len(req.Path) == 1 && len(req.Path[0]) > 0) {
+ return 0, unix.ENOTDIR
+ }
+ }
+ for i, name := range req.Path {
+ // First component is allowed to be "".
+ if i == 0 && len(name) == 0 {
+ continue
+ }
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+ }
+
+ return fd.impl.WalkStat(c, comm, req.Path)
+}
+
+// OpenAtHandler handles the OpenAt RPC.
+func OpenAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req OpenAtReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ // Only keep allowed open flags.
+ if allowedFlags := req.Flags & allowedOpenFlags; allowedFlags != req.Flags {
+ log.Debugf("discarding open flags that are not allowed: old open flags = %d, new open flags = %d", req.Flags, allowedFlags)
+ req.Flags = allowedFlags
+ }
+
+ accessMode := req.Flags & unix.O_ACCMODE
+ trunc := req.Flags&unix.O_TRUNC != 0
+ if c.readonly && (accessMode != unix.O_RDONLY || trunc) {
+ return 0, unix.EROFS
+ }
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if fd.IsDir() {
+ // Directory is not truncatable and must be opened with O_RDONLY.
+ if accessMode != unix.O_RDONLY || trunc {
+ return 0, unix.EISDIR
+ }
+ }
+
+ return fd.impl.Open(c, comm, req.Flags)
+}
+
+// OpenCreateAtHandler handles the OpenCreateAt RPC.
+func OpenCreateAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req OpenCreateAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ // Only keep allowed open flags.
+ if allowedFlags := req.Flags & allowedOpenFlags; allowedFlags != req.Flags {
+ log.Debugf("discarding open flags that are not allowed: old open flags = %d, new open flags = %d", req.Flags, allowedFlags)
+ req.Flags = allowedFlags
+ }
+
+ name := string(req.Name)
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+
+ return fd.impl.OpenCreate(c, comm, req.Mode, req.UID, req.GID, name, uint32(req.Flags))
+}
+
+// CloseHandler handles the Close RPC.
+func CloseHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req CloseReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+ for _, fd := range req.FDs {
+ c.RemoveFD(fd)
+ }
+
+ // There is no response message for this.
+ return 0, nil
+}
+
+// FSyncHandler handles the FSync RPC.
+func FSyncHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req FsyncReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ // Return the first error we encounter, but sync everything we can
+ // regardless.
+ var retErr error
+ for _, fdid := range req.FDs {
+ if err := c.fsyncFD(fdid); err != nil && retErr == nil {
+ retErr = err
+ }
+ }
+
+ // There is no response message for this.
+ return 0, retErr
+}
+
+func (c *Connection) fsyncFD(id FDID) error {
+ fd, err := c.LookupOpenFD(id)
+ if err != nil {
+ return err
+ }
+ defer fd.DecRef(nil)
+ return fd.impl.Sync(c)
+}
+
+// PWriteHandler handles the PWrite RPC.
+func PWriteHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req PWriteReq
+ // Note that it is an optimized Unmarshal operation which avoids any buffer
+ // allocation and copying. req.Buf just points to payload. This is safe to do
+ // as the handler owns payload and req's lifetime is limited to the handler.
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+ fd, err := c.LookupOpenFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.writable {
+ return 0, unix.EBADF
+ }
+ return fd.impl.Write(c, comm, req.Buf, uint64(req.Offset))
+}
+
+// PReadHandler handles the PRead RPC.
+func PReadHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req PReadReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupOpenFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.readable {
+ return 0, unix.EBADF
+ }
+ return fd.impl.Read(c, comm, req.Offset, req.Count)
+}
+
+// MkdirAtHandler handles the MkdirAt RPC.
+func MkdirAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req MkdirAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ name := string(req.Name)
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+ return fd.impl.Mkdir(c, comm, req.Mode, req.UID, req.GID, name)
+}
+
+// MknodAtHandler handles the MknodAt RPC.
+func MknodAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req MknodAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ name := string(req.Name)
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+ return fd.impl.Mknod(c, comm, req.Mode, req.UID, req.GID, name, uint32(req.Minor), uint32(req.Major))
+}
+
+// SymlinkAtHandler handles the SymlinkAt RPC.
+func SymlinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req SymlinkAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ name := string(req.Name)
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+ return fd.impl.Symlink(c, comm, name, string(req.Target), req.UID, req.GID)
+}
+
+// LinkAtHandler handles the LinkAt RPC.
+func LinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req LinkAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ name := string(req.Name)
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+
+ targetFD, err := c.LookupControlFD(req.Target)
+ if err != nil {
+ return 0, err
+ }
+ defer targetFD.DecRef(nil)
+ return targetFD.impl.Link(c, comm, fd.impl, name)
+}
+
+// FStatFSHandler handles the FStatFS RPC.
+func FStatFSHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req FStatFSReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ return fd.impl.StatFS(c, comm)
+}
+
+// FAllocateHandler handles the FAllocate RPC.
+func FAllocateHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req FAllocateReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupOpenFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.writable {
+ return 0, unix.EBADF
+ }
+ return 0, fd.impl.Allocate(c, req.Mode, req.Offset, req.Length)
+}
+
+// ReadLinkAtHandler handles the ReadLinkAt RPC.
+func ReadLinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req ReadLinkAtReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsSymlink() {
+ return 0, unix.EINVAL
+ }
+ return fd.impl.Readlink(c, comm)
+}
+
+// FlushHandler handles the Flush RPC.
+func FlushHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req FlushReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupOpenFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+
+ return 0, fd.impl.Flush(c)
+}
+
+// ConnectHandler handles the Connect RPC.
+func ConnectHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req ConnectReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsSocket() {
+ return 0, unix.ENOTSOCK
+ }
+ return 0, fd.impl.Connect(c, comm, req.SockType)
+}
+
+// UnlinkAtHandler handles the UnlinkAt RPC.
+func UnlinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req UnlinkAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ name := string(req.Name)
+ if err := checkSafeName(name); err != nil {
+ return 0, err
+ }
+
+ fd, err := c.LookupControlFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+ return 0, fd.impl.Unlink(c, name, uint32(req.Flags))
+}
+
+// RenameAtHandler handles the RenameAt RPC.
+func RenameAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req RenameAtReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ newName := string(req.NewName)
+ if err := checkSafeName(newName); err != nil {
+ return 0, err
+ }
+
+ renamed, err := c.LookupControlFD(req.Renamed)
+ if err != nil {
+ return 0, err
+ }
+ defer renamed.DecRef(nil)
+
+ newDir, err := c.LookupControlFD(req.NewDir)
+ if err != nil {
+ return 0, err
+ }
+ defer newDir.DecRef(nil)
+ if !newDir.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+
+ // Hold RenameMu for writing throughout the rename; this is important.
+ c.server.RenameMu.Lock()
+ defer c.server.RenameMu.Unlock()
+
+ if renamed.parent == nil {
+ // renamed is root.
+ return 0, unix.EBUSY
+ }
+
+ oldParentPath := renamed.parent.FilePathLocked()
+ oldPath := oldParentPath + "/" + renamed.name
+ if newName == renamed.name && oldParentPath == newDir.FilePathLocked() {
+ // Nothing to do.
+ return 0, nil
+ }
+
+ updateControlFD, cleanUp, err := renamed.impl.RenameLocked(c, newDir.impl, newName)
+ if err != nil {
+ return 0, err
+ }
+
+ c.server.forEachMountPoint(func(root *ControlFD) {
+ if !strings.HasPrefix(oldPath, root.name) {
+ return
+ }
+ pit := fspath.Parse(oldPath[len(root.name):]).Begin
+ root.renameRecursiveLocked(newDir, newName, pit, updateControlFD)
+ })
+
+ if cleanUp != nil {
+ cleanUp()
+ }
+ return 0, nil
+}
+
+// Precondition: rename mutex must be locked for writing.
+func (fd *ControlFD) renameRecursiveLocked(newDir *ControlFD, newName string, pit fspath.Iterator, updateControlFD func(ControlFDImpl)) {
+ if !pit.Ok() {
+ // fd should be renamed.
+ fd.clearParentLocked()
+ fd.setParentLocked(newDir)
+ fd.name = newName
+ if updateControlFD != nil {
+ updateControlFD(fd.impl)
+ }
+ return
+ }
+
+ cur := pit.String()
+ next := pit.Next()
+ // No need to hold fd.childrenMu because RenameMu is locked for writing.
+ for child := fd.children.Front(); child != nil; child = child.Next() {
+ if child.name == cur {
+ child.renameRecursiveLocked(newDir, newName, next, updateControlFD)
+ }
+ }
+}
+
+// Getdents64Handler handles the Getdents64 RPC.
+func Getdents64Handler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req Getdents64Req
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupOpenFD(req.DirFD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ if !fd.controlFD.IsDir() {
+ return 0, unix.ENOTDIR
+ }
+
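+ // A negative Count encodes a rewind request from the client: seek back to
+ // the start of the directory before reading, using |Count| as the count.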
+ seek0 := false
+ if req.Count < 0 {
+ seek0 = true
+ req.Count = -req.Count
+ }
+ return fd.impl.Getdent64(c, comm, uint32(req.Count), seek0)
+}
+
+// FGetXattrHandler handles the FGetXattr RPC.
+func FGetXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req FGetXattrReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ return fd.impl.GetXattr(c, comm, string(req.Name), uint32(req.BufSize))
+}
+
+// FSetXattrHandler handles the FSetXattr RPC.
+func FSetXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req FSetXattrReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ return 0, fd.impl.SetXattr(c, string(req.Name), string(req.Value), uint32(req.Flags))
+}
+
+// FListXattrHandler handles the FListXattr RPC.
+func FListXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ var req FListXattrReq
+ req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ return fd.impl.ListXattr(c, comm, req.Size)
+}
+
+// FRemoveXattrHandler handles the FRemoveXattr RPC.
+func FRemoveXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+ if c.readonly {
+ return 0, unix.EROFS
+ }
+ var req FRemoveXattrReq
+ req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+ fd, err := c.LookupControlFD(req.FD)
+ if err != nil {
+ return 0, err
+ }
+ defer fd.DecRef(nil)
+ return 0, fd.impl.RemoveXattr(c, comm, string(req.Name))
+}
+
+// checkSafeName returns nil if name is a safe single path component: it must
+// be non-empty, must not be "." or "..", and must not contain '/'. Otherwise
+// it returns EINVAL.
+func checkSafeName(name string) error {
+ if name != "" && !strings.Contains(name, "/") && name != "." && name != ".." {
+ return nil
+ }
+ return unix.EINVAL
+}
diff --git a/pkg/lisafs/lisafs.go b/pkg/lisafs/lisafs.go
new file mode 100644
index 000000000..4d8e956ab
--- /dev/null
+++ b/pkg/lisafs/lisafs.go
@@ -0,0 +1,18 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package lisafs (LInux SAndbox FileSystem) defines the protocol for
+// filesystem RPCs between an untrusted Sandbox (client) and a trusted
+// filesystem server.
+package lisafs
diff --git a/pkg/lisafs/message.go b/pkg/lisafs/message.go
new file mode 100644
index 000000000..722afd0be
--- /dev/null
+++ b/pkg/lisafs/message.go
@@ -0,0 +1,1251 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "math"
+ "os"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+// Messages have two parts:
+// * A transport header used to decipher received messages.
+// * A byte array referred to as "payload" which contains the actual message.
+//
+// "dataLen" refers to the size of both combined.
+
+// MID (message ID) is used to identify messages to parse from payload.
+//
+// +marshal slice:MIDSlice
+type MID uint16
+
+// These constants are used to identify their corresponding message types.
+const (
+ // Error is only used in responses to pass errors to the client.
+ Error MID = 0
+
+ // Mount is used to establish a connection between the client and the server
+ // mount point. lisafs requires that the client makes a successful Mount RPC
+ // before making other RPCs.
+ Mount MID = 1
+
+ // Channel requests to start a new communication channel.
+ Channel MID = 2
+
+ // FStat requests the stat(2) results for a specified file.
+ FStat MID = 3
+
+ // SetStat requests to change file attributes. Note that there is no single
+ // corresponding Linux syscall. This is a conglomeration of fchmod(2),
+ // fchown(2), ftruncate(2) and futimesat(2).
+ SetStat MID = 4
+
+ // Walk requests to walk the specified path starting from the specified
+ // directory. Server-side path traversal is terminated preemptively on
+ // symlink entries because they can cause non-linear traversal.
+ Walk MID = 5
+
+ // WalkStat is the same as Walk, except the following differences:
+ // * If the first path component is "", then it also returns stat results
+ // for the directory where the walk starts.
+ // * Does not return Inode, just the Stat results for each path component.
+ WalkStat MID = 6
+
+ // OpenAt is analogous to openat(2). It does not perform any walk. It merely
+ // duplicates the control FD with the open flags passed.
+ OpenAt MID = 7
+
+ // OpenCreateAt is analogous to openat(2) with O_CREAT|O_EXCL added to flags.
+ // It also returns the newly created file inode.
+ OpenCreateAt MID = 8
+
+ // Close is analogous to close(2) but can work on multiple FDs.
+ Close MID = 9
+
+ // FSync is analogous to fsync(2) but can work on multiple FDs.
+ FSync MID = 10
+
+ // PWrite is analogous to pwrite(2).
+ PWrite MID = 11
+
+ // PRead is analogous to pread(2).
+ PRead MID = 12
+
+ // MkdirAt is analogous to mkdirat(2).
+ MkdirAt MID = 13
+
+ // MknodAt is analogous to mknodat(2).
+ MknodAt MID = 14
+
+ // SymlinkAt is analogous to symlinkat(2).
+ SymlinkAt MID = 15
+
+ // LinkAt is analogous to linkat(2).
+ LinkAt MID = 16
+
+ // FStatFS is analogous to fstatfs(2).
+ FStatFS MID = 17
+
+ // FAllocate is analogous to fallocate(2).
+ FAllocate MID = 18
+
+ // ReadLinkAt is analogous to readlinkat(2).
+ ReadLinkAt MID = 19
+
+ // Flush cleans up the file state. Its behavior is implementation
+ // dependent and might not even be supported in server implementations.
+ Flush MID = 20
+
+ // Connect is loosely analogous to connect(2).
+ Connect MID = 21
+
+ // UnlinkAt is analogous to unlinkat(2).
+ UnlinkAt MID = 22
+
+ // RenameAt is loosely analogous to renameat(2).
+ RenameAt MID = 23
+
+ // Getdents64 is analogous to getdents64(2).
+ Getdents64 MID = 24
+
+ // FGetXattr is analogous to fgetxattr(2).
+ FGetXattr MID = 25
+
+ // FSetXattr is analogous to fsetxattr(2).
+ FSetXattr MID = 26
+
+ // FListXattr is analogous to flistxattr(2).
+ FListXattr MID = 27
+
+ // FRemoveXattr is analogous to fremovexattr(2).
+ FRemoveXattr MID = 28
+)
+
+const (
+ // NoUID is a sentinel used to indicate no valid UID.
+ NoUID UID = math.MaxUint32
+
+ // NoGID is a sentinel used to indicate no valid GID.
+ NoGID GID = math.MaxUint32
+)
+
+// MaxMessageSize is the recommended max message size that can be used by
+// connections. Server implementations may choose to use other values.
+func MaxMessageSize() uint32 {
+ // Return HugePageSize - PageSize so that when flipcall packet window is
+ // created with MaxMessageSize() + flipcall header size + channel header
+ // size, HugePageSize is allocated and can be backed by a single huge page
+ // if supported by the underlying memfd.
+ return uint32(hostarch.HugePageSize - os.Getpagesize())
+}
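+
+// For instance, with 2 MiB huge pages and 4 KiB base pages (typical on
+// x86-64), MaxMessageSize() returns 2097152 - 4096 = 2093056 bytes.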
+
+// TODO(gvisor.dev/issue/6450): Once this is resolved:
+// * Update manual implementations and function signatures.
+// * Update RPC handlers and appropriate callers to handle errors correctly.
+// * Update manual implementations to get rid of buffer shifting.
+
+// UID represents a user ID.
+//
+// +marshal
+type UID uint32
+
+// Ok returns true if uid is not NoUID.
+func (uid UID) Ok() bool {
+ return uid != NoUID
+}
+
+// GID represents a group ID.
+//
+// +marshal
+type GID uint32
+
+// Ok returns true if gid is not NoGID.
+func (gid GID) Ok() bool {
+ return gid != NoGID
+}
+
+// NoopMarshal is a noop implementation of marshal.Marshallable.MarshalBytes.
+func NoopMarshal([]byte) {}
+
+// NoopUnmarshal is a noop implementation of marshal.Marshallable.UnmarshalBytes.
+func NoopUnmarshal([]byte) {}
+
+// SizedString represents a string in memory. The marshalled string bytes are
+// preceded by a uint32 signifying the string length.
+type SizedString string
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *SizedString) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + len(*s)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *SizedString) MarshalBytes(dst []byte) {
+ strLen := primitive.Uint32(len(*s))
+ strLen.MarshalUnsafe(dst)
+ dst = dst[strLen.SizeBytes():]
+ // Copy without any allocation.
+ copy(dst[:strLen], *s)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *SizedString) UnmarshalBytes(src []byte) {
+ var strLen primitive.Uint32
+ strLen.UnmarshalUnsafe(src)
+ src = src[strLen.SizeBytes():]
+ // Take the hit; this leads to an allocation + memcpy. There is no way around it.
+ *s = SizedString(src[:strLen])
+}
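+
+// To illustrate the wire format: on a little-endian host, SizedString("abc")
+// marshals to 7 bytes, a uint32 length followed by the raw string bytes:
+// [0x03 0x00 0x00 0x00 'a' 'b' 'c'].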
+
+// StringArray represents an array of SizedStrings in memory. The marshalled
+// array data is preceded by a uint32 signifying the array length.
+type StringArray []string
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *StringArray) SizeBytes() int {
+ size := (*primitive.Uint32)(nil).SizeBytes()
+ for _, str := range *s {
+ sstr := SizedString(str)
+ size += sstr.SizeBytes()
+ }
+ return size
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *StringArray) MarshalBytes(dst []byte) {
+ arrLen := primitive.Uint32(len(*s))
+ arrLen.MarshalUnsafe(dst)
+ dst = dst[arrLen.SizeBytes():]
+ for _, str := range *s {
+ sstr := SizedString(str)
+ sstr.MarshalBytes(dst)
+ dst = dst[sstr.SizeBytes():]
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *StringArray) UnmarshalBytes(src []byte) {
+ var arrLen primitive.Uint32
+ arrLen.UnmarshalUnsafe(src)
+ src = src[arrLen.SizeBytes():]
+
+ if cap(*s) < int(arrLen) {
+ *s = make([]string, arrLen)
+ } else {
+ *s = (*s)[:arrLen]
+ }
+
+ for i := primitive.Uint32(0); i < arrLen; i++ {
+ var sstr SizedString
+ sstr.UnmarshalBytes(src)
+ src = src[sstr.SizeBytes():]
+ (*s)[i] = string(sstr)
+ }
+}
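+
+// Building on the SizedString layout above, StringArray{"a", "bc"} marshals
+// on a little-endian host as a uint32 array length followed by each element
+// as a SizedString: [0x02 0 0 0][0x01 0 0 0 'a'][0x02 0 0 0 'b' 'c'].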
+
+// Inode represents an inode on the remote filesystem.
+//
+// +marshal slice:InodeSlice
+type Inode struct {
+ ControlFD FDID
+ _ uint32 // Need to make struct packed.
+ Stat linux.Statx
+}
+
+// MountReq represents a Mount request.
+type MountReq struct {
+ MountPath SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MountReq) SizeBytes() int {
+ return m.MountPath.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MountReq) MarshalBytes(dst []byte) {
+ m.MountPath.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MountReq) UnmarshalBytes(src []byte) {
+ m.MountPath.UnmarshalBytes(src)
+}
+
+// MountResp represents a Mount response.
+type MountResp struct {
+ Root Inode
+ // MaxMessageSize is the maximum size of messages communicated between the
+ // client and server in bytes. This includes the communication header.
+ MaxMessageSize primitive.Uint32
+ // SupportedMs holds all the supported messages.
+ SupportedMs []MID
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MountResp) SizeBytes() int {
+ return m.Root.SizeBytes() +
+ m.MaxMessageSize.SizeBytes() +
+ (*primitive.Uint16)(nil).SizeBytes() +
+ (len(m.SupportedMs) * (*MID)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MountResp) MarshalBytes(dst []byte) {
+ m.Root.MarshalUnsafe(dst)
+ dst = dst[m.Root.SizeBytes():]
+ m.MaxMessageSize.MarshalUnsafe(dst)
+ dst = dst[m.MaxMessageSize.SizeBytes():]
+ numSupported := primitive.Uint16(len(m.SupportedMs))
+ numSupported.MarshalBytes(dst)
+ dst = dst[numSupported.SizeBytes():]
+ MarshalUnsafeMIDSlice(m.SupportedMs, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MountResp) UnmarshalBytes(src []byte) {
+ m.Root.UnmarshalUnsafe(src)
+ src = src[m.Root.SizeBytes():]
+ m.MaxMessageSize.UnmarshalUnsafe(src)
+ src = src[m.MaxMessageSize.SizeBytes():]
+ var numSupported primitive.Uint16
+ numSupported.UnmarshalBytes(src)
+ src = src[numSupported.SizeBytes():]
+ m.SupportedMs = make([]MID, numSupported)
+ UnmarshalUnsafeMIDSlice(m.SupportedMs, src)
+}
+
+// ChannelResp is the response to the create channel request.
+//
+// +marshal
+type ChannelResp struct {
+ dataOffset int64
+ dataLength uint64
+}
+
+// ErrorResp is returned to represent an error while handling a request.
+//
+// +marshal
+type ErrorResp struct {
+ errno uint32
+}
+
+// StatReq requests the stat results for the specified FD.
+//
+// +marshal
+type StatReq struct {
+ FD FDID
+}
+
+// SetStatReq is used to set attributes on FDs.
+//
+// +marshal
+type SetStatReq struct {
+ FD FDID
+ _ uint32
+ Mask uint32
+ Mode uint32 // Only permissions part is settable.
+ UID UID
+ GID GID
+ Size uint64
+ Atime linux.Timespec
+ Mtime linux.Timespec
+}
+
+// SetStatResp is used to communicate SetStat results. It contains a mask
+// representing the failed changes. It also contains the errno of the failed
+// set attribute operation. If multiple operations failed then any of those
+// errnos can be returned.
+//
+// +marshal
+type SetStatResp struct {
+ FailureMask uint32
+ FailureErrNo uint32
+}
+
+// WalkReq is used to request a walk of multiple path components at once. This
+// is used for both Walk and WalkStat.
+type WalkReq struct {
+ DirFD FDID
+ Path StringArray
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (w *WalkReq) SizeBytes() int {
+ return w.DirFD.SizeBytes() + w.Path.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (w *WalkReq) MarshalBytes(dst []byte) {
+ w.DirFD.MarshalUnsafe(dst)
+ dst = dst[w.DirFD.SizeBytes():]
+ w.Path.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (w *WalkReq) UnmarshalBytes(src []byte) {
+ w.DirFD.UnmarshalUnsafe(src)
+ src = src[w.DirFD.SizeBytes():]
+ w.Path.UnmarshalBytes(src)
+}
+
+// WalkStatus is used to indicate the reason for partial/unsuccessful server
+// side Walk operations. Please note that partial/unsuccessful walk operations
+// do not necessarily fail the RPC. The RPC is successful with a failure hint
+// which can be used by the client to infer server-side state.
+type WalkStatus = primitive.Uint8
+
+const (
+ // WalkSuccess indicates that all path components were successfully walked.
+ WalkSuccess WalkStatus = iota
+
+ // WalkComponentDoesNotExist indicates that the walk was prematurely
+ // terminated because an intermediate path component does not exist on the
+ // server. The results for all preceding existing path components are returned.
+ WalkComponentDoesNotExist
+
+ // WalkComponentSymlink indicates that the walk was prematurely
+ // terminated because an intermediate path component was a symlink. It is
+ // not safe to resolve symlinks remotely because the server is unaware of
+ // the client's mount points.
+ WalkComponentSymlink
+)
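+
+// Clients are expected to branch on the returned status. A sketch, where resp
+// is a hypothetical WalkResp received for the path ["a", "b", "c"]:
+//
+//	switch resp.Status {
+//	case WalkSuccess:
+//		// All three components were walked; resp.Inodes has one entry per
+//		// component.
+//	case WalkComponentDoesNotExist:
+//		// resp.Inodes covers only the prefix of components that exist.
+//	case WalkComponentSymlink:
+//		// The walk stopped at a symlink; the client must resolve it locally
+//		// and issue another Walk RPC for the remaining components.
+//	}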
+
+// WalkResp is used to communicate the inodes walked by the server.
+type WalkResp struct {
+ Status WalkStatus
+ Inodes []Inode
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (w *WalkResp) SizeBytes() int {
+ return w.Status.SizeBytes() +
+ (*primitive.Uint32)(nil).SizeBytes() + (len(w.Inodes) * (*Inode)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (w *WalkResp) MarshalBytes(dst []byte) {
+ w.Status.MarshalUnsafe(dst)
+ dst = dst[w.Status.SizeBytes():]
+
+ numInodes := primitive.Uint32(len(w.Inodes))
+ numInodes.MarshalUnsafe(dst)
+ dst = dst[numInodes.SizeBytes():]
+
+ MarshalUnsafeInodeSlice(w.Inodes, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (w *WalkResp) UnmarshalBytes(src []byte) {
+ w.Status.UnmarshalUnsafe(src)
+ src = src[w.Status.SizeBytes():]
+
+ var numInodes primitive.Uint32
+ numInodes.UnmarshalUnsafe(src)
+ src = src[numInodes.SizeBytes():]
+
+ if cap(w.Inodes) < int(numInodes) {
+ w.Inodes = make([]Inode, numInodes)
+ } else {
+ w.Inodes = w.Inodes[:numInodes]
+ }
+ UnmarshalUnsafeInodeSlice(w.Inodes, src)
+}
+
+// WalkStatResp is used to communicate stat results for WalkStat.
+type WalkStatResp struct {
+ Stats []linux.Statx
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (w *WalkStatResp) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + (len(w.Stats) * linux.SizeOfStatx)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (w *WalkStatResp) MarshalBytes(dst []byte) {
+ numStats := primitive.Uint32(len(w.Stats))
+ numStats.MarshalUnsafe(dst)
+ dst = dst[numStats.SizeBytes():]
+
+ linux.MarshalUnsafeStatxSlice(w.Stats, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (w *WalkStatResp) UnmarshalBytes(src []byte) {
+ var numStats primitive.Uint32
+ numStats.UnmarshalUnsafe(src)
+ src = src[numStats.SizeBytes():]
+
+ if cap(w.Stats) < int(numStats) {
+ w.Stats = make([]linux.Statx, numStats)
+ } else {
+ w.Stats = w.Stats[:numStats]
+ }
+ linux.UnmarshalUnsafeStatxSlice(w.Stats, src)
+}
+
+// OpenAtReq is used to open the file referred to by an existing FD with the
+// specified flags.
+//
+// +marshal
+type OpenAtReq struct {
+ FD FDID
+ Flags uint32
+}
+
+// OpenAtResp is used to communicate the newly created FD.
+//
+// +marshal
+type OpenAtResp struct {
+ NewFD FDID
+}
+
+// +marshal
+type createCommon struct {
+ DirFD FDID
+ Mode linux.FileMode
+ _ uint16 // Need to make struct packed.
+ UID UID
+ GID GID
+}
+
+// OpenCreateAtReq is used to make OpenCreateAt requests.
+type OpenCreateAtReq struct {
+ createCommon
+ Name SizedString
+ Flags primitive.Uint32
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (o *OpenCreateAtReq) SizeBytes() int {
+ return o.createCommon.SizeBytes() + o.Name.SizeBytes() + o.Flags.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (o *OpenCreateAtReq) MarshalBytes(dst []byte) {
+ o.createCommon.MarshalUnsafe(dst)
+ dst = dst[o.createCommon.SizeBytes():]
+ o.Name.MarshalBytes(dst)
+ dst = dst[o.Name.SizeBytes():]
+ o.Flags.MarshalUnsafe(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (o *OpenCreateAtReq) UnmarshalBytes(src []byte) {
+ o.createCommon.UnmarshalUnsafe(src)
+ src = src[o.createCommon.SizeBytes():]
+ o.Name.UnmarshalBytes(src)
+ src = src[o.Name.SizeBytes():]
+ o.Flags.UnmarshalUnsafe(src)
+}
+
+// OpenCreateAtResp is used to communicate successful OpenCreateAt results.
+//
+// +marshal
+type OpenCreateAtResp struct {
+ Child Inode
+ NewFD FDID
+ _ uint32 // Need to make struct packed.
+}
+
+// FdArray is a utility type which implements marshal.Marshallable for
+// communicating an array of FDIDs. In the marshalled form, the array data is
+// preceded by a uint32 denoting the array length.
+type FdArray []FDID
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (f *FdArray) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + (len(*f) * (*FDID)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (f *FdArray) MarshalBytes(dst []byte) {
+ arrLen := primitive.Uint32(len(*f))
+ arrLen.MarshalUnsafe(dst)
+ dst = dst[arrLen.SizeBytes():]
+ MarshalUnsafeFDIDSlice(*f, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (f *FdArray) UnmarshalBytes(src []byte) {
+ var arrLen primitive.Uint32
+ arrLen.UnmarshalUnsafe(src)
+ src = src[arrLen.SizeBytes():]
+ if cap(*f) < int(arrLen) {
+ *f = make(FdArray, arrLen)
+ } else {
+ *f = (*f)[:arrLen]
+ }
+ UnmarshalUnsafeFDIDSlice(*f, src)
+}
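+
+// Example (illustrative): assuming FDID is a 4-byte identifier, as the
+// padding in the structs above suggests, FdArray{4, 7} marshals to 12 bytes:
+// a uint32 length of 2 followed by two FDIDs:
+//
+//	fds := FdArray{4, 7}
+//	buf := make([]byte, fds.SizeBytes()) // 4 + 2*4 = 12 bytes.
+//	fds.MarshalBytes(buf)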
+
+// CloseReq is used to close(2) FDs.
+type CloseReq struct {
+ FDs FdArray
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (c *CloseReq) SizeBytes() int {
+ return c.FDs.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (c *CloseReq) MarshalBytes(dst []byte) {
+ c.FDs.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (c *CloseReq) UnmarshalBytes(src []byte) {
+ c.FDs.UnmarshalBytes(src)
+}
+
+// FsyncReq is used to fsync(2) FDs.
+type FsyncReq struct {
+ FDs FdArray
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (f *FsyncReq) SizeBytes() int {
+ return f.FDs.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (f *FsyncReq) MarshalBytes(dst []byte) {
+ f.FDs.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (f *FsyncReq) UnmarshalBytes(src []byte) {
+ f.FDs.UnmarshalBytes(src)
+}
+
+// PReadReq is used to pread(2) on an FD.
+//
+// +marshal
+type PReadReq struct {
+ Offset uint64
+ FD FDID
+ Count uint32
+}
+
+// PReadResp is used to return the result of pread(2).
+type PReadResp struct {
+ NumBytes primitive.Uint32
+ Buf []byte
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (r *PReadResp) SizeBytes() int {
+ return r.NumBytes.SizeBytes() + int(r.NumBytes)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (r *PReadResp) MarshalBytes(dst []byte) {
+ r.NumBytes.MarshalUnsafe(dst)
+ dst = dst[r.NumBytes.SizeBytes():]
+ copy(dst[:r.NumBytes], r.Buf[:r.NumBytes])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (r *PReadResp) UnmarshalBytes(src []byte) {
+ r.NumBytes.UnmarshalUnsafe(src)
+ src = src[r.NumBytes.SizeBytes():]
+
+ // We expect the client to have already allocated r.Buf, which optimally
+ // points directly to usermem. Copy the payload straight into it.
+ copy(r.Buf[:r.NumBytes], src[:r.NumBytes])
+}
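+
+// Example (hypothetical client-side usage): the caller preallocates Buf so
+// that the payload is copied directly into its final destination:
+//
+//	resp := PReadResp{Buf: dstBuf} // dstBuf optimally aliases usermem.
+//	resp.UnmarshalBytes(payload)
+//	// The first resp.NumBytes bytes of dstBuf are now populated.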
+
+// PWriteReq is used to pwrite(2) on an FD.
+type PWriteReq struct {
+ Offset primitive.Uint64
+ FD FDID
+ NumBytes primitive.Uint32
+ Buf []byte
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (w *PWriteReq) SizeBytes() int {
+ return w.Offset.SizeBytes() + w.FD.SizeBytes() + w.NumBytes.SizeBytes() + int(w.NumBytes)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (w *PWriteReq) MarshalBytes(dst []byte) {
+ w.Offset.MarshalUnsafe(dst)
+ dst = dst[w.Offset.SizeBytes():]
+ w.FD.MarshalUnsafe(dst)
+ dst = dst[w.FD.SizeBytes():]
+ w.NumBytes.MarshalUnsafe(dst)
+ dst = dst[w.NumBytes.SizeBytes():]
+ copy(dst[:w.NumBytes], w.Buf[:w.NumBytes])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (w *PWriteReq) UnmarshalBytes(src []byte) {
+ w.Offset.UnmarshalUnsafe(src)
+ src = src[w.Offset.SizeBytes():]
+ w.FD.UnmarshalUnsafe(src)
+ src = src[w.FD.SizeBytes():]
+ w.NumBytes.UnmarshalUnsafe(src)
+ src = src[w.NumBytes.SizeBytes():]
+
+ // This is an optimization. Assuming that the server is making this call, it
+ // is safe to just point to src rather than allocating and copying.
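+ // Note that w.Buf then aliases src, so it is only valid until the
+ // underlying payload buffer is reused for another message.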
+ w.Buf = src[:w.NumBytes]
+}
+
+// PWriteResp is used to return the result of pwrite(2).
+//
+// +marshal
+type PWriteResp struct {
+ Count uint64
+}
+
+// MkdirAtReq is used to make MkdirAt requests.
+type MkdirAtReq struct {
+ createCommon
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MkdirAtReq) SizeBytes() int {
+ return m.createCommon.SizeBytes() + m.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MkdirAtReq) MarshalBytes(dst []byte) {
+ m.createCommon.MarshalUnsafe(dst)
+ dst = dst[m.createCommon.SizeBytes():]
+ m.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MkdirAtReq) UnmarshalBytes(src []byte) {
+ m.createCommon.UnmarshalUnsafe(src)
+ src = src[m.createCommon.SizeBytes():]
+ m.Name.UnmarshalBytes(src)
+}
+
+// MkdirAtResp is the response to a successful MkdirAt request.
+//
+// +marshal
+type MkdirAtResp struct {
+ ChildDir Inode
+}
+
+// MknodAtReq is used to make MknodAt requests.
+type MknodAtReq struct {
+ createCommon
+ Name SizedString
+ Minor primitive.Uint32
+ Major primitive.Uint32
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MknodAtReq) SizeBytes() int {
+ return m.createCommon.SizeBytes() + m.Name.SizeBytes() + m.Minor.SizeBytes() + m.Major.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MknodAtReq) MarshalBytes(dst []byte) {
+ m.createCommon.MarshalUnsafe(dst)
+ dst = dst[m.createCommon.SizeBytes():]
+ m.Name.MarshalBytes(dst)
+ dst = dst[m.Name.SizeBytes():]
+ m.Minor.MarshalUnsafe(dst)
+ dst = dst[m.Minor.SizeBytes():]
+ m.Major.MarshalUnsafe(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MknodAtReq) UnmarshalBytes(src []byte) {
+ m.createCommon.UnmarshalUnsafe(src)
+ src = src[m.createCommon.SizeBytes():]
+ m.Name.UnmarshalBytes(src)
+ src = src[m.Name.SizeBytes():]
+ m.Minor.UnmarshalUnsafe(src)
+ src = src[m.Minor.SizeBytes():]
+ m.Major.UnmarshalUnsafe(src)
+}
+
+// MknodAtResp is the response to a successful MknodAt request.
+//
+// +marshal
+type MknodAtResp struct {
+ Child Inode
+}
+
+// SymlinkAtReq is used to make SymlinkAt requests.
+type SymlinkAtReq struct {
+ DirFD FDID
+ Name SizedString
+ Target SizedString
+ UID UID
+ GID GID
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *SymlinkAtReq) SizeBytes() int {
+ return s.DirFD.SizeBytes() + s.Name.SizeBytes() + s.Target.SizeBytes() + s.UID.SizeBytes() + s.GID.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *SymlinkAtReq) MarshalBytes(dst []byte) {
+ s.DirFD.MarshalUnsafe(dst)
+ dst = dst[s.DirFD.SizeBytes():]
+ s.Name.MarshalBytes(dst)
+ dst = dst[s.Name.SizeBytes():]
+ s.Target.MarshalBytes(dst)
+ dst = dst[s.Target.SizeBytes():]
+ s.UID.MarshalUnsafe(dst)
+ dst = dst[s.UID.SizeBytes():]
+ s.GID.MarshalUnsafe(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *SymlinkAtReq) UnmarshalBytes(src []byte) {
+ s.DirFD.UnmarshalUnsafe(src)
+ src = src[s.DirFD.SizeBytes():]
+ s.Name.UnmarshalBytes(src)
+ src = src[s.Name.SizeBytes():]
+ s.Target.UnmarshalBytes(src)
+ src = src[s.Target.SizeBytes():]
+ s.UID.UnmarshalUnsafe(src)
+ src = src[s.UID.SizeBytes():]
+ s.GID.UnmarshalUnsafe(src)
+}
+
+// SymlinkAtResp is the response to a successful SymlinkAt request.
+//
+// +marshal
+type SymlinkAtResp struct {
+ Symlink Inode
+}
+
+// LinkAtReq is used to make LinkAt requests.
+type LinkAtReq struct {
+ DirFD FDID
+ Target FDID
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (l *LinkAtReq) SizeBytes() int {
+ return l.DirFD.SizeBytes() + l.Target.SizeBytes() + l.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (l *LinkAtReq) MarshalBytes(dst []byte) {
+ l.DirFD.MarshalUnsafe(dst)
+ dst = dst[l.DirFD.SizeBytes():]
+ l.Target.MarshalUnsafe(dst)
+ dst = dst[l.Target.SizeBytes():]
+ l.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (l *LinkAtReq) UnmarshalBytes(src []byte) {
+ l.DirFD.UnmarshalUnsafe(src)
+ src = src[l.DirFD.SizeBytes():]
+ l.Target.UnmarshalUnsafe(src)
+ src = src[l.Target.SizeBytes():]
+ l.Name.UnmarshalBytes(src)
+}
+
+// LinkAtResp is used to respond to a successful LinkAt request.
+//
+// +marshal
+type LinkAtResp struct {
+ Link Inode
+}
+
+// FStatFSReq is used to request StatFS results for the specified FD.
+//
+// +marshal
+type FStatFSReq struct {
+ FD FDID
+}
+
+// StatFS is the response to a successful FStatFS request.
+//
+// +marshal
+type StatFS struct {
+ Type uint64
+ BlockSize int64
+ Blocks uint64
+ BlocksFree uint64
+ BlocksAvailable uint64
+ Files uint64
+ FilesFree uint64
+ NameLength uint64
+}
+
+// FAllocateReq is used to request fallocate(2) on an FD. It has no response.
+//
+// +marshal
+type FAllocateReq struct {
+ FD FDID
+ _ uint32
+ Mode uint64
+ Offset uint64
+ Length uint64
+}
+
+// ReadLinkAtReq is used to readlinkat(2) the specified FD.
+//
+// +marshal
+type ReadLinkAtReq struct {
+ FD FDID
+}
+
+// ReadLinkAtResp is used to communicate ReadLinkAt results.
+type ReadLinkAtResp struct {
+ Target SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (r *ReadLinkAtResp) SizeBytes() int {
+ return r.Target.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (r *ReadLinkAtResp) MarshalBytes(dst []byte) {
+ r.Target.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (r *ReadLinkAtResp) UnmarshalBytes(src []byte) {
+ r.Target.UnmarshalBytes(src)
+}
+
+// FlushReq is used to make Flush requests.
+//
+// +marshal
+type FlushReq struct {
+ FD FDID
+}
+
+// ConnectReq is used to make a Connect request.
+//
+// +marshal
+type ConnectReq struct {
+ FD FDID
+ // SockType is used to specify the socket type to connect to. As a special
+ // case, SockType = 0 means that the socket type does not matter and the
+ // requester will accept any socket type.
+ SockType uint32
+}
+
+// UnlinkAtReq is used to make UnlinkAt requests.
+type UnlinkAtReq struct {
+ DirFD FDID
+ Name SizedString
+ Flags primitive.Uint32
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (u *UnlinkAtReq) SizeBytes() int {
+ return u.DirFD.SizeBytes() + u.Name.SizeBytes() + u.Flags.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (u *UnlinkAtReq) MarshalBytes(dst []byte) {
+ u.DirFD.MarshalUnsafe(dst)
+ dst = dst[u.DirFD.SizeBytes():]
+ u.Name.MarshalBytes(dst)
+ dst = dst[u.Name.SizeBytes():]
+ u.Flags.MarshalUnsafe(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (u *UnlinkAtReq) UnmarshalBytes(src []byte) {
+ u.DirFD.UnmarshalUnsafe(src)
+ src = src[u.DirFD.SizeBytes():]
+ u.Name.UnmarshalBytes(src)
+ src = src[u.Name.SizeBytes():]
+ u.Flags.UnmarshalUnsafe(src)
+}
+
+// RenameAtReq is used to make Rename requests. Note that the request takes in
+// the to-be-renamed file's FD instead of oldDir and oldName as renameat(2)
+// does.
+type RenameAtReq struct {
+ Renamed FDID
+ NewDir FDID
+ NewName SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (r *RenameAtReq) SizeBytes() int {
+ return r.Renamed.SizeBytes() + r.NewDir.SizeBytes() + r.NewName.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (r *RenameAtReq) MarshalBytes(dst []byte) {
+ r.Renamed.MarshalUnsafe(dst)
+ dst = dst[r.Renamed.SizeBytes():]
+ r.NewDir.MarshalUnsafe(dst)
+ dst = dst[r.NewDir.SizeBytes():]
+ r.NewName.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (r *RenameAtReq) UnmarshalBytes(src []byte) {
+ r.Renamed.UnmarshalUnsafe(src)
+ src = src[r.Renamed.SizeBytes():]
+ r.NewDir.UnmarshalUnsafe(src)
+ src = src[r.NewDir.SizeBytes():]
+ r.NewName.UnmarshalBytes(src)
+}
+
+// Getdents64Req is used to make Getdents64 requests.
+//
+// +marshal
+type Getdents64Req struct {
+ DirFD FDID
+ // Count is the number of bytes to read. A negative value of Count is used to
+ // indicate that the implementation must lseek(0, SEEK_SET) before calling
+ // getdents64(2). Implementations must use the absolute value of Count to
+ // determine the number of bytes to read.
+ Count int32
+}
+
+// Dirent64 is analogous to struct linux_dirent64.
+type Dirent64 struct {
+ Ino primitive.Uint64
+ DevMinor primitive.Uint32
+ DevMajor primitive.Uint32
+ Off primitive.Uint64
+ Type primitive.Uint8
+ Name SizedString
+}
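+
+// Example (illustrative round trip; DT_REG = 8 as in dirent.h):
+//
+//	d := Dirent64{Ino: 42, Type: 8 /* DT_REG */, Name: "foo"}
+//	buf := make([]byte, d.SizeBytes())
+//	d.MarshalBytes(buf)
+//
+//	var got Dirent64
+//	got.UnmarshalBytes(buf)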
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (d *Dirent64) SizeBytes() int {
+ return d.Ino.SizeBytes() + d.DevMinor.SizeBytes() + d.DevMajor.SizeBytes() + d.Off.SizeBytes() + d.Type.SizeBytes() + d.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (d *Dirent64) MarshalBytes(dst []byte) {
+ d.Ino.MarshalUnsafe(dst)
+ dst = dst[d.Ino.SizeBytes():]
+ d.DevMinor.MarshalUnsafe(dst)
+ dst = dst[d.DevMinor.SizeBytes():]
+ d.DevMajor.MarshalUnsafe(dst)
+ dst = dst[d.DevMajor.SizeBytes():]
+ d.Off.MarshalUnsafe(dst)
+ dst = dst[d.Off.SizeBytes():]
+ d.Type.MarshalUnsafe(dst)
+ dst = dst[d.Type.SizeBytes():]
+ d.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (d *Dirent64) UnmarshalBytes(src []byte) {
+ d.Ino.UnmarshalUnsafe(src)
+ src = src[d.Ino.SizeBytes():]
+ d.DevMinor.UnmarshalUnsafe(src)
+ src = src[d.DevMinor.SizeBytes():]
+ d.DevMajor.UnmarshalUnsafe(src)
+ src = src[d.DevMajor.SizeBytes():]
+ d.Off.UnmarshalUnsafe(src)
+ src = src[d.Off.SizeBytes():]
+ d.Type.UnmarshalUnsafe(src)
+ src = src[d.Type.SizeBytes():]
+ d.Name.UnmarshalBytes(src)
+}
+
+// Getdents64Resp is used to communicate getdents64 results.
+type Getdents64Resp struct {
+ Dirents []Dirent64
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (g *Getdents64Resp) SizeBytes() int {
+ ret := (*primitive.Uint32)(nil).SizeBytes()
+ for i := range g.Dirents {
+ ret += g.Dirents[i].SizeBytes()
+ }
+ return ret
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (g *Getdents64Resp) MarshalBytes(dst []byte) {
+ numDirents := primitive.Uint32(len(g.Dirents))
+ numDirents.MarshalUnsafe(dst)
+ dst = dst[numDirents.SizeBytes():]
+ for i := range g.Dirents {
+ g.Dirents[i].MarshalBytes(dst)
+ dst = dst[g.Dirents[i].SizeBytes():]
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (g *Getdents64Resp) UnmarshalBytes(src []byte) {
+ var numDirents primitive.Uint32
+ numDirents.UnmarshalUnsafe(src)
+ if cap(g.Dirents) < int(numDirents) {
+ g.Dirents = make([]Dirent64, numDirents)
+ } else {
+ g.Dirents = g.Dirents[:numDirents]
+ }
+
+ src = src[numDirents.SizeBytes():]
+ for i := range g.Dirents {
+ g.Dirents[i].UnmarshalBytes(src)
+ src = src[g.Dirents[i].SizeBytes():]
+ }
+}
+
+// FGetXattrReq is used to make FGetXattr requests. The response to this is
+// just a SizedString containing the xattr value.
+type FGetXattrReq struct {
+ FD FDID
+ BufSize primitive.Uint32
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (g *FGetXattrReq) SizeBytes() int {
+ return g.FD.SizeBytes() + g.BufSize.SizeBytes() + g.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (g *FGetXattrReq) MarshalBytes(dst []byte) {
+ g.FD.MarshalUnsafe(dst)
+ dst = dst[g.FD.SizeBytes():]
+ g.BufSize.MarshalUnsafe(dst)
+ dst = dst[g.BufSize.SizeBytes():]
+ g.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (g *FGetXattrReq) UnmarshalBytes(src []byte) {
+ g.FD.UnmarshalUnsafe(src)
+ src = src[g.FD.SizeBytes():]
+ g.BufSize.UnmarshalUnsafe(src)
+ src = src[g.BufSize.SizeBytes():]
+ g.Name.UnmarshalBytes(src)
+}
+
+// FGetXattrResp is used to respond to an FGetXattr request.
+type FGetXattrResp struct {
+ Value SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (g *FGetXattrResp) SizeBytes() int {
+ return g.Value.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (g *FGetXattrResp) MarshalBytes(dst []byte) {
+ g.Value.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (g *FGetXattrResp) UnmarshalBytes(src []byte) {
+ g.Value.UnmarshalBytes(src)
+}
+
+// FSetXattrReq is used to make FSetXattr requests. It has no response.
+type FSetXattrReq struct {
+ FD FDID
+ Flags primitive.Uint32
+ Name SizedString
+ Value SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *FSetXattrReq) SizeBytes() int {
+ return s.FD.SizeBytes() + s.Flags.SizeBytes() + s.Name.SizeBytes() + s.Value.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *FSetXattrReq) MarshalBytes(dst []byte) {
+ s.FD.MarshalUnsafe(dst)
+ dst = dst[s.FD.SizeBytes():]
+ s.Flags.MarshalUnsafe(dst)
+ dst = dst[s.Flags.SizeBytes():]
+ s.Name.MarshalBytes(dst)
+ dst = dst[s.Name.SizeBytes():]
+ s.Value.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *FSetXattrReq) UnmarshalBytes(src []byte) {
+ s.FD.UnmarshalUnsafe(src)
+ src = src[s.FD.SizeBytes():]
+ s.Flags.UnmarshalUnsafe(src)
+ src = src[s.Flags.SizeBytes():]
+ s.Name.UnmarshalBytes(src)
+ src = src[s.Name.SizeBytes():]
+ s.Value.UnmarshalBytes(src)
+}
+
+// FRemoveXattrReq is used to make FRemoveXattr requests. It has no response.
+type FRemoveXattrReq struct {
+ FD FDID
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (r *FRemoveXattrReq) SizeBytes() int {
+ return r.FD.SizeBytes() + r.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (r *FRemoveXattrReq) MarshalBytes(dst []byte) {
+ r.FD.MarshalUnsafe(dst)
+ dst = dst[r.FD.SizeBytes():]
+ r.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (r *FRemoveXattrReq) UnmarshalBytes(src []byte) {
+ r.FD.UnmarshalUnsafe(src)
+ src = src[r.FD.SizeBytes():]
+ r.Name.UnmarshalBytes(src)
+}
+
+// FListXattrReq is used to make FListXattr requests.
+//
+// +marshal
+type FListXattrReq struct {
+ FD FDID
+ _ uint32
+ Size uint64
+}
+
+// FListXattrResp is used to respond to FListXattr requests.
+type FListXattrResp struct {
+ Xattrs StringArray
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (l *FListXattrResp) SizeBytes() int {
+ return l.Xattrs.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (l *FListXattrResp) MarshalBytes(dst []byte) {
+ l.Xattrs.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (l *FListXattrResp) UnmarshalBytes(src []byte) {
+ l.Xattrs.UnmarshalBytes(src)
+}
diff --git a/pkg/lisafs/sample_message.go b/pkg/lisafs/sample_message.go
new file mode 100644
index 000000000..3868dfa08
--- /dev/null
+++ b/pkg/lisafs/sample_message.go
@@ -0,0 +1,110 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+// MsgSimple is a sample packed struct which can be used to test message passing.
+//
+// +marshal slice:Msg1Slice
+type MsgSimple struct {
+ A uint16
+ B uint16
+ C uint32
+ D uint64
+}
+
+// Randomize randomizes the contents of m.
+func (m *MsgSimple) Randomize() {
+ m.A = uint16(rand.Uint32())
+ m.B = uint16(rand.Uint32())
+ m.C = rand.Uint32()
+ m.D = rand.Uint64()
+}
+
+// MsgDynamic is a sample dynamic struct which can be used to test message passing.
+//
+// +marshal dynamic
+type MsgDynamic struct {
+ N primitive.Uint32
+ Arr []MsgSimple
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MsgDynamic) SizeBytes() int {
+ return m.N.SizeBytes() +
+ (int(m.N) * (*MsgSimple)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MsgDynamic) MarshalBytes(dst []byte) {
+ m.N.MarshalUnsafe(dst)
+ dst = dst[m.N.SizeBytes():]
+ MarshalUnsafeMsg1Slice(m.Arr, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MsgDynamic) UnmarshalBytes(src []byte) {
+ m.N.UnmarshalUnsafe(src)
+ src = src[m.N.SizeBytes():]
+ m.Arr = make([]MsgSimple, m.N)
+ UnmarshalUnsafeMsg1Slice(m.Arr, src)
+}
+
+// Randomize randomizes the contents of m.
+func (m *MsgDynamic) Randomize(arrLen int) {
+ m.N = primitive.Uint32(arrLen)
+ m.Arr = make([]MsgSimple, arrLen)
+ for i := 0; i < arrLen; i++ {
+ m.Arr[i].Randomize()
+ }
+}
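+
+// Example (illustrative round trip, mirroring what the tests in sock_test.go
+// exercise):
+//
+//	var msg MsgDynamic
+//	msg.Randomize(10)
+//	buf := make([]byte, msg.SizeBytes())
+//	msg.MarshalBytes(buf)
+//
+//	var got MsgDynamic
+//	got.UnmarshalBytes(buf)
+//	// reflect.DeepEqual(&msg, &got) should now hold.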
+
+// P9Version mimics p9.TVersion and p9.Rversion.
+//
+// +marshal dynamic
+type P9Version struct {
+ MSize primitive.Uint32
+ Version string
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (v *P9Version) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + (*primitive.Uint16)(nil).SizeBytes() + len(v.Version)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (v *P9Version) MarshalBytes(dst []byte) {
+ v.MSize.MarshalUnsafe(dst)
+ dst = dst[v.MSize.SizeBytes():]
+ versionLen := primitive.Uint16(len(v.Version))
+ versionLen.MarshalUnsafe(dst)
+ dst = dst[versionLen.SizeBytes():]
+ copy(dst, v.Version)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (v *P9Version) UnmarshalBytes(src []byte) {
+ v.MSize.UnmarshalUnsafe(src)
+ src = src[v.MSize.SizeBytes():]
+ var versionLen primitive.Uint16
+ versionLen.UnmarshalUnsafe(src)
+ src = src[versionLen.SizeBytes():]
+ v.Version = string(src[:versionLen])
+}
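+
+// Example (illustrative): the version string is prefixed by a uint16 length,
+// so "9P2000.L" with an MSize of 1<<20 marshals to 4 + 2 + 8 = 14 bytes:
+//
+//	v := P9Version{MSize: 1 << 20, Version: "9P2000.L"}
+//	buf := make([]byte, v.SizeBytes())
+//	v.MarshalBytes(buf)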
diff --git a/pkg/lisafs/server.go b/pkg/lisafs/server.go
new file mode 100644
index 000000000..7515355ec
--- /dev/null
+++ b/pkg/lisafs/server.go
@@ -0,0 +1,113 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Server serves a filesystem tree. Multiple connections on different mount
+// points can be started on a server. The server provides utilities to safely
+// modify the filesystem tree across its connections (mount points). Note that
+// it does not support synchronizing filesystem tree mutations across other
+// servers serving the same filesystem subtree. Server also manages the
+// lifecycle of all connections.
+type Server struct {
+ // connWg counts the number of active connections being tracked.
+ connWg sync.WaitGroup
+
+ // RenameMu synchronizes rename operations within this filesystem tree.
+ RenameMu sync.RWMutex
+
+ // handlers is a list of RPC handlers which can be indexed by the handler's
+ // corresponding MID.
+ handlers []RPCHandler
+
+ // mountPoints keeps track of all the mount points this server serves.
+ mpMu sync.RWMutex
+ mountPoints []*ControlFD
+
+ // impl is the server implementation which embeds this server.
+ impl ServerImpl
+}
+
+// Init must be called before first use of the server.
+func (s *Server) Init(impl ServerImpl) {
+ s.impl = impl
+ s.handlers = handlers[:]
+}
+
+// InitTestOnly is the same as Init except that it allows swapping out the
+// underlying handlers with custom ones. This is for tests only.
+func (s *Server) InitTestOnly(impl ServerImpl, handlers []RPCHandler) {
+ s.impl = impl
+ s.handlers = handlers
+}
+
+// WithRenameReadLock invokes fn with the server's rename mutex locked for
+// reading. This ensures that no rename operations occur concurrently.
+func (s *Server) WithRenameReadLock(fn func() error) error {
+ s.RenameMu.RLock()
+ err := fn()
+ s.RenameMu.RUnlock()
+ return err
+}
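+
+// Example (hypothetical handler usage): operations that depend on a stable
+// path take the rename read lock so that no rename can race with them:
+//
+//	err := s.WithRenameReadLock(func() error {
+//		// Walk/stat/open using the FD's path here.
+//		return nil
+//	})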
+
+// StartConnection starts the connection in a separate goroutine and tracks it.
+func (s *Server) StartConnection(c *Connection) {
+ s.connWg.Add(1)
+ go func() {
+ c.Run()
+ s.connWg.Done()
+ }()
+}
+
+// Wait waits for all connections started via StartConnection() to terminate.
+func (s *Server) Wait() {
+ s.connWg.Wait()
+}
+
+func (s *Server) addMountPoint(root *ControlFD) {
+ s.mpMu.Lock()
+ defer s.mpMu.Unlock()
+ s.mountPoints = append(s.mountPoints, root)
+}
+
+func (s *Server) forEachMountPoint(fn func(root *ControlFD)) {
+ s.mpMu.RLock()
+ defer s.mpMu.RUnlock()
+ for _, mp := range s.mountPoints {
+ fn(mp)
+ }
+}
+
+// ServerImpl contains the implementation details for a Server.
+// Implementations of ServerImpl should contain their associated Server by
+// value as their first field.
+type ServerImpl interface {
+ // Mount is called when a Mount RPC is made. It mounts the connection at
+ // mountPath.
+ //
+ // Precondition: mountPath == path.Clean(mountPath).
+ Mount(c *Connection, mountPath string) (ControlFDImpl, Inode, error)
+
+ // SupportedMessages returns a list of messages that the server
+ // implementation supports.
+ SupportedMessages() []MID
+
+ // MaxMessageSize is the maximum payload length (in bytes) that can be sent
+ // to this server implementation.
+ MaxMessageSize() uint32
+}
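+
+// Example (hypothetical skeleton of a ServerImpl; all names are
+// illustrative):
+//
+//	type localFS struct {
+//		lisafs.Server // Embedded by value as the first field.
+//	}
+//
+//	func (fs *localFS) Mount(c *lisafs.Connection, mountPath string) (lisafs.ControlFDImpl, lisafs.Inode, error) {
+//		// Open mountPath on the host and wrap it in a ControlFDImpl.
+//		...
+//	}
+//
+//	func (fs *localFS) SupportedMessages() []lisafs.MID { ... }
+//	func (fs *localFS) MaxMessageSize() uint32          { return 1 << 20 }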
diff --git a/pkg/lisafs/sock.go b/pkg/lisafs/sock.go
new file mode 100644
index 000000000..88210242f
--- /dev/null
+++ b/pkg/lisafs/sock.go
@@ -0,0 +1,208 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "io"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+var sockHeaderLen = uint32((*sockHeader)(nil).SizeBytes())
+
+// sockHeader is the header present in front of each message received on a UDS.
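+//
+// Wire layout (assuming MID is a uint16, as the uint16 padding below
+// suggests): 4 bytes of payload length, 2 bytes of message ID, and 2 bytes
+// of padding, 8 bytes in total.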
+//
+// +marshal
+type sockHeader struct {
+ payloadLen uint32
+ message MID
+ _ uint16 // Need to make struct packed.
+}
+
+// sockCommunicator implements Communicator. It is not thread-safe.
+type sockCommunicator struct {
+ fdTracker
+ sock *unet.Socket
+ buf []byte
+}
+
+var _ Communicator = (*sockCommunicator)(nil)
+
+func newSockComm(sock *unet.Socket) *sockCommunicator {
+ return &sockCommunicator{
+ sock: sock,
+ buf: make([]byte, sockHeaderLen),
+ }
+}
+
+func (s *sockCommunicator) FD() int {
+ return s.sock.FD()
+}
+
+func (s *sockCommunicator) destroy() {
+ s.sock.Close()
+}
+
+func (s *sockCommunicator) shutdown() {
+ if err := s.sock.Shutdown(); err != nil {
+ log.Warningf("Socket.Shutdown() failed (FD: %d): %v", s.sock.FD(), err)
+ }
+}
+
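+// resizeBuf grows s.buf to at least size bytes, reusing the existing
+// capacity when possible so that steady-state messages do not allocate.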
+func (s *sockCommunicator) resizeBuf(size uint32) {
+ if cap(s.buf) < int(size) {
+ s.buf = s.buf[:cap(s.buf)]
+ s.buf = append(s.buf, make([]byte, int(size)-cap(s.buf))...)
+ } else {
+ s.buf = s.buf[:size]
+ }
+}
+
+// PayloadBuf implements Communicator.PayloadBuf.
+func (s *sockCommunicator) PayloadBuf(size uint32) []byte {
+ s.resizeBuf(sockHeaderLen + size)
+ return s.buf[sockHeaderLen : sockHeaderLen+size]
+}
+
+// SndRcvMessage implements Communicator.SndRcvMessage.
+func (s *sockCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) {
+ if err := s.sndPrepopulatedMsg(m, payloadLen, nil); err != nil {
+ return 0, 0, err
+ }
+
+ return s.rcvMsg(wantFDs)
+}
+
+// sndPrepopulatedMsg assumes that s.buf has already been populated with
+// `payloadLen` bytes of data.
+func (s *sockCommunicator) sndPrepopulatedMsg(m MID, payloadLen uint32, fds []int) error {
+ header := sockHeader{payloadLen: payloadLen, message: m}
+ header.MarshalUnsafe(s.buf)
+ dataLen := sockHeaderLen + payloadLen
+ return writeTo(s.sock, [][]byte{s.buf[:dataLen]}, int(dataLen), fds)
+}
+
+// writeTo writes the passed iovec to the UDS and donates any passed FDs.
+func writeTo(sock *unet.Socket, iovec [][]byte, dataLen int, fds []int) error {
+ w := sock.Writer(true)
+ if len(fds) > 0 {
+ w.PackFDs(fds...)
+ }
+
+ fdsUnpacked := false
+ for n := 0; n < dataLen; {
+ cur, err := w.WriteVec(iovec)
+ if err != nil {
+ return err
+ }
+ n += cur
+
+ // Fast common path.
+ if n >= dataLen {
+ break
+ }
+
+ // Consume iovecs.
+ for consumed := 0; consumed < cur; {
+ if len(iovec[0]) <= cur-consumed {
+ consumed += len(iovec[0])
+ iovec = iovec[1:]
+ } else {
+ iovec[0] = iovec[0][cur-consumed:]
+ break
+ }
+ }
+
+ if n > 0 && !fdsUnpacked {
+ // Don't resend any control message.
+ fdsUnpacked = true
+ w.UnpackFDs()
+ }
+ }
+ return nil
+}
+
+// rcvMsg reads the message header and payload from the UDS. It also starts
+// tracking any donated FDs.
+func (s *sockCommunicator) rcvMsg(wantFDs uint8) (MID, uint32, error) {
+ fds, err := readFrom(s.sock, s.buf[:sockHeaderLen], wantFDs)
+ if err != nil {
+ return 0, 0, err
+ }
+ for _, fd := range fds {
+ s.TrackFD(fd)
+ }
+
+ var header sockHeader
+ header.UnmarshalUnsafe(s.buf)
+
+ // No payload? We are done.
+ if header.payloadLen == 0 {
+ return header.message, 0, nil
+ }
+
+ if _, err := readFrom(s.sock, s.PayloadBuf(header.payloadLen), 0); err != nil {
+ return 0, 0, err
+ }
+
+ return header.message, header.payloadLen, nil
+}
+
+// readFrom fills the passed buffer with data from the socket. It also returns
+// any donated FDs.
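+// Donated FDs accompany the first byte of a message, so they are extracted
+// after the first successful read and FD reception is then disabled.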
+func readFrom(sock *unet.Socket, buf []byte, wantFDs uint8) ([]int, error) {
+ r := sock.Reader(true)
+ r.EnableFDs(int(wantFDs))
+
+ var (
+ fds []int
+ fdInit bool
+ )
+ n := len(buf)
+ for got := 0; got < n; {
+ cur, err := r.ReadVec([][]byte{buf[got:]})
+
+ // Ignore EOF if cur > 0.
+ if err != nil && (err != io.EOF || cur == 0) {
+ r.CloseFDs()
+ return nil, err
+ }
+
+ if !fdInit && cur > 0 {
+ fds, err = r.ExtractFDs()
+ if err != nil {
+ return nil, err
+ }
+
+ fdInit = true
+ r.EnableFDs(0)
+ }
+
+ got += cur
+ }
+ return fds, nil
+}
+
+func closeFDs(fds []int) {
+ for _, fd := range fds {
+ if fd >= 0 {
+ unix.Close(fd)
+ }
+ }
+}
diff --git a/pkg/lisafs/sock_test.go b/pkg/lisafs/sock_test.go
new file mode 100644
index 000000000..387f4b7a8
--- /dev/null
+++ b/pkg/lisafs/sock_test.go
@@ -0,0 +1,217 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "bytes"
+ "math/rand"
+ "reflect"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+func runSocketTest(t *testing.T, fun1 func(*sockCommunicator), fun2 func(*sockCommunicator)) {
+ sock1, sock2, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer sock1.Close()
+ defer sock2.Close()
+
+ var testWg sync.WaitGroup
+ testWg.Add(2)
+
+ go func() {
+ fun1(newSockComm(sock1))
+ testWg.Done()
+ }()
+
+ go func() {
+ fun2(newSockComm(sock2))
+ testWg.Done()
+ }()
+
+ testWg.Wait()
+}
+
+func TestReadWrite(t *testing.T) {
+ // Create random data to send.
+ n := 10000
+ data := make([]byte, n)
+ if _, err := rand.Read(data); err != nil {
+ t.Fatalf("rand.Read(data) failed: %v", err)
+ }
+
+ runSocketTest(t, func(comm *sockCommunicator) {
+ // Scatter that data into two parts using Iovecs while sending.
+ mid := n / 2
+ if err := writeTo(comm.sock, [][]byte{data[:mid], data[mid:]}, n, nil); err != nil {
+ t.Errorf("writeTo socket failed: %v", err)
+ }
+ }, func(comm *sockCommunicator) {
+ gotData := make([]byte, n)
+ if _, err := readFrom(comm.sock, gotData, 0); err != nil {
+ t.Fatalf("reading from socket failed: %v", err)
+ }
+
+ // Make sure we got the right data.
+ if !bytes.Equal(data, gotData) {
+ t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData)
+ }
+ })
+}
+
+func TestFDDonation(t *testing.T) {
+ n := 10
+ data := make([]byte, n)
+ if _, err := rand.Read(data); err != nil {
+ t.Fatalf("rand.Read(data) failed: %v", err)
+ }
+
+ // Try donating FDs to these files.
+ path1 := "/dev/null"
+ path2 := "/dev"
+ path3 := "/dev/random"
+
+ runSocketTest(t, func(comm *sockCommunicator) {
+ devNullFD, err := unix.Open(path1, unix.O_RDONLY, 0)
+ if err != nil {
+ t.Fatalf("open(%s) failed: %v", path1, err)
+ }
+ defer unix.Close(devNullFD)
+ devFD, err := unix.Open(path2, unix.O_RDONLY, 0)
+ if err != nil {
+ t.Fatalf("open(%s) failed: %v", path2, err)
+ }
+ defer unix.Close(devFD)
+ devRandomFD, err := unix.Open(path3, unix.O_RDONLY, 0)
+ if err != nil {
+ t.Fatalf("open(%s) failed: %v", path3, err)
+ }
+ defer unix.Close(devRandomFD)
+ if err := writeTo(comm.sock, [][]byte{data}, n, []int{devNullFD, devFD, devRandomFD}); err != nil {
+ t.Errorf("writeTo socket failed: %v", err)
+ }
+ }, func(comm *sockCommunicator) {
+ gotData := make([]byte, n)
+ fds, err := readFrom(comm.sock, gotData, 3)
+ if err != nil {
+ t.Fatalf("reading from socket failed: %v", err)
+ }
+ defer closeFDs(fds)
+
+ if !bytes.Equal(data, gotData) {
+ t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData)
+ }
+
+ if len(fds) != 3 {
+ t.Fatalf("wanted 3 FD, got %d", len(fds))
+ }
+
+ // Check that the FDs actually point to the correct file.
+ compareFDWithFile(t, fds[0], path1)
+ compareFDWithFile(t, fds[1], path2)
+ compareFDWithFile(t, fds[2], path3)
+ })
+}
+
+func compareFDWithFile(t *testing.T, fd int, path string) {
+ var want unix.Stat_t
+ if err := unix.Stat(path, &want); err != nil {
+ t.Fatalf("stat(%s) failed: %v", path, err)
+ }
+
+ var got unix.Stat_t
+ if err := unix.Fstat(fd, &got); err != nil {
+ t.Fatalf("fstat on donated FD failed: %v", err)
+ }
+
+ if got.Ino != want.Ino || got.Dev != want.Dev {
+ t.Errorf("FD does not point to %s, want = %+v, got = %+v", path, want, got)
+ }
+}
+
+func testSndMsg(comm *sockCommunicator, m MID, msg marshal.Marshallable) error {
+ var payloadLen uint32
+ if msg != nil {
+ payloadLen = uint32(msg.SizeBytes())
+ msg.MarshalUnsafe(comm.PayloadBuf(payloadLen))
+ }
+ return comm.sndPrepopulatedMsg(m, payloadLen, nil)
+}
+
+func TestSndRcvMessage(t *testing.T) {
+ req := &MsgSimple{}
+ req.Randomize()
+ reqM := MID(1)
+
+ // Create a massive random response.
+ var resp MsgDynamic
+ resp.Randomize(100)
+ respM := MID(2)
+
+ runSocketTest(t, func(comm *sockCommunicator) {
+ if err := testSndMsg(comm, reqM, req); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ checkMessageReceive(t, comm, respM, &resp)
+ }, func(comm *sockCommunicator) {
+ checkMessageReceive(t, comm, reqM, req)
+ if err := testSndMsg(comm, respM, &resp); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ })
+}
+
+func TestSndRcvMessageNoPayload(t *testing.T) {
+ reqM := MID(1)
+ respM := MID(2)
+ runSocketTest(t, func(comm *sockCommunicator) {
+ if err := testSndMsg(comm, reqM, nil); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ checkMessageReceive(t, comm, respM, nil)
+ }, func(comm *sockCommunicator) {
+ checkMessageReceive(t, comm, reqM, nil)
+ if err := testSndMsg(comm, respM, nil); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ })
+}
+
+func checkMessageReceive(t *testing.T, comm *sockCommunicator, wantM MID, wantMsg marshal.Marshallable) {
+ gotM, payloadLen, err := comm.rcvMsg(0)
+ if err != nil {
+ t.Fatalf("readMessageFrom failed: %v", err)
+ }
+ if gotM != wantM {
+ t.Errorf("got incorrect message ID: got = %d, want = %d", gotM, wantM)
+ }
+ if wantMsg == nil {
+ if payloadLen != 0 {
+ t.Errorf("no payload expect but got %d bytes", payloadLen)
+ }
+ } else {
+ gotMsg := reflect.New(reflect.ValueOf(wantMsg).Elem().Type()).Interface().(marshal.Marshallable)
+ gotMsg.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+ if !reflect.DeepEqual(wantMsg, gotMsg) {
+ t.Errorf("msg differs: want = %+v, got = %+v", wantMsg, gotMsg)
+ }
+ }
+}
diff --git a/pkg/lisafs/testsuite/BUILD b/pkg/lisafs/testsuite/BUILD
new file mode 100644
index 000000000..b4a542b3a
--- /dev/null
+++ b/pkg/lisafs/testsuite/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "testsuite",
+ testonly = True,
+ srcs = ["testsuite.go"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/lisafs",
+ "//pkg/unet",
+ "@com_github_syndtr_gocapability//capability:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/lisafs/testsuite/testsuite.go b/pkg/lisafs/testsuite/testsuite.go
new file mode 100644
index 000000000..5fc7c364d
--- /dev/null
+++ b/pkg/lisafs/testsuite/testsuite.go
@@ -0,0 +1,637 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testsuite provides an integration testing suite for lisafs.
+// These tests are intended for servers serving the local filesystem.
+package testsuite
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "math/rand"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/syndtr/gocapability/capability"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/lisafs"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Tester is implemented by the client code using this test suite. This
+// interface abstracts away all the caller-specific details.
+type Tester interface {
+ // NewServer returns a new instance of the tester server.
+ NewServer(t *testing.T) *lisafs.Server
+
+ // LinkSupported returns true if the backing server supports LinkAt.
+ LinkSupported() bool
+
+ // SetUserGroupIDSupported returns true if the backing server supports
+ // changing UID/GID for files.
+ SetUserGroupIDSupported() bool
+}
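+
+// Example (hypothetical Tester for a local filesystem server; names and
+// bodies are illustrative):
+//
+//	type localTester struct{}
+//
+//	func (localTester) NewServer(t *testing.T) *lisafs.Server { ... }
+//	func (localTester) LinkSupported() bool                   { return true }
+//	func (localTester) SetUserGroupIDSupported() bool         { return unix.Getuid() == 0 }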
+
+// RunAllLocalFSTests runs all local FS tests as subtests.
+func RunAllLocalFSTests(t *testing.T, tester Tester) {
+ for name, testFn := range localFSTests {
+ t.Run(name, func(t *testing.T) {
+ runServerClient(t, tester, testFn)
+ })
+ }
+}
+
+type testFunc func(context.Context, *testing.T, Tester, lisafs.ClientFD)
+
+var localFSTests = map[string]testFunc{
+ "Stat": testStat,
+ "RegularFileIO": testRegularFileIO,
+ "RegularFileOpen": testRegularFileOpen,
+ "SetStat": testSetStat,
+ "Allocate": testAllocate,
+ "StatFS": testStatFS,
+ "Unlink": testUnlink,
+ "Symlink": testSymlink,
+ "HardLink": testHardLink,
+ "Walk": testWalk,
+ "Rename": testRename,
+ "Mknod": testMknod,
+ "Getdents": testGetdents,
+}
+
+func runServerClient(t *testing.T, tester Tester, testFn testFunc) {
+ mountPath, err := ioutil.TempDir(os.Getenv("TEST_TMPDIR"), "")
+ if err != nil {
+ t.Fatalf("creation of temporary mountpoint failed: %v", err)
+ }
+ defer os.RemoveAll(mountPath)
+
+ // fsgofer should run with a umask of 0, because we want to preserve file
+ // modes exactly for testing purposes.
+ unix.Umask(0)
+
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+
+ server := tester.NewServer(t)
+ conn, err := server.CreateConnection(serverSocket, false /* readonly */)
+ if err != nil {
+ t.Fatalf("starting connection failed: %v", err)
+ }
+ server.StartConnection(conn)
+
+ c, root, err := lisafs.NewClient(clientSocket, mountPath)
+ if err != nil {
+ t.Fatalf("client creation failed: %v", err)
+ }
+
+ if !root.ControlFD.Ok() {
+ t.Fatalf("root control FD is not valid")
+ }
+ rootFile := c.NewFD(root.ControlFD)
+
+ ctx := context.Background()
+ testFn(ctx, t, tester, rootFile)
+ closeFD(ctx, t, rootFile)
+
+ c.Close() // This should trigger client and server shutdown.
+ server.Wait()
+}
+
+func closeFD(ctx context.Context, t testing.TB, fdLisa lisafs.ClientFD) {
+ if err := fdLisa.Close(ctx); err != nil {
+ t.Errorf("failed to close FD: %v", err)
+ }
+}
+
+func statTo(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, stat *linux.Statx) {
+ if err := fdLisa.StatTo(ctx, stat); err != nil {
+ t.Fatalf("stat failed: %v", err)
+ }
+}
+
+func openCreateFile(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx, lisafs.ClientFD, int) {
+ child, childFD, childHostFD, err := fdLisa.OpenCreateAt(ctx, name, unix.O_RDWR, 0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()))
+ if err != nil {
+ t.Fatalf("OpenCreateAt failed: %v", err)
+ }
+ if childHostFD == -1 {
+ t.Error("no host FD donated")
+ }
+ client := fdLisa.Client()
+ return client.NewFD(child.ControlFD), child.Stat, client.NewFD(childFD), childHostFD
+}
+
+func openFile(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, flags uint32, isReg bool) (lisafs.ClientFD, int) {
+ newFD, hostFD, err := fdLisa.OpenAt(ctx, flags)
+ if err != nil {
+ t.Fatalf("OpenAt failed: %v", err)
+ }
+ if hostFD == -1 && isReg {
+ t.Error("no host FD donated")
+ }
+ return fdLisa.Client().NewFD(newFD), hostFD
+}
+
+func unlinkFile(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string, isDir bool) {
+ var flags uint32
+ if isDir {
+ flags = unix.AT_REMOVEDIR
+ }
+ if err := dir.UnlinkAt(ctx, name, flags); err != nil {
+ t.Errorf("unlinking file %s failed: %v", name, err)
+ }
+}
+
+func symlink(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name, target string) (lisafs.ClientFD, linux.Statx) {
+ linkIno, err := dir.SymlinkAt(ctx, name, target, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()))
+ if err != nil {
+ t.Fatalf("symlink failed: %v", err)
+ }
+ return dir.Client().NewFD(linkIno.ControlFD), linkIno.Stat
+}
+
+func link(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string, target lisafs.ClientFD) (lisafs.ClientFD, linux.Statx) {
+ linkIno, err := dir.LinkAt(ctx, target.ID(), name)
+ if err != nil {
+ t.Fatalf("link failed: %v", err)
+ }
+ return dir.Client().NewFD(linkIno.ControlFD), linkIno.Stat
+}
+
+func mkdir(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx) {
+ childIno, err := dir.MkdirAt(ctx, name, 0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()))
+ if err != nil {
+ t.Fatalf("mkdir failed: %v", err)
+ }
+ return dir.Client().NewFD(childIno.ControlFD), childIno.Stat
+}
+
+func mknod(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx) {
+ nodeIno, err := dir.MknodAt(ctx, name, unix.S_IFREG|0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()), 0, 0)
+ if err != nil {
+ t.Fatalf("mknod failed: %v", err)
+ }
+ return dir.Client().NewFD(nodeIno.ControlFD), nodeIno.Stat
+}
+
+func walk(ctx context.Context, t *testing.T, dir lisafs.ClientFD, names []string) []lisafs.Inode {
+ _, inodes, err := dir.WalkMultiple(ctx, names)
+ if err != nil {
+ t.Fatalf("walk failed while trying to walk components %+v: %v", names, err)
+ }
+ return inodes
+}
+
+func walkStat(ctx context.Context, t *testing.T, dir lisafs.ClientFD, names []string) []linux.Statx {
+ stats, err := dir.WalkStat(ctx, names)
+ if err != nil {
+ t.Fatalf("walk failed while trying to walk components %+v: %v", names, err)
+ }
+ return stats
+}
+
+func writeFD(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, buf []byte) error {
+ count, err := fdLisa.Write(ctx, buf, off)
+ if err != nil {
+ return err
+ }
+ if int(count) != len(buf) {
+ t.Errorf("partial write: buf size = %d, written = %d", len(buf), count)
+ }
+ return nil
+}
+
+func readFDAndCmp(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, want []byte) {
+ buf := make([]byte, len(want))
+ n, err := fdLisa.Read(ctx, buf, off)
+ if err != nil {
+ t.Errorf("read failed: %v", err)
+ return
+ }
+ if int(n) != len(want) {
+ t.Errorf("partial read: buf size = %d, read = %d", len(want), n)
+ return
+ }
+ if !bytes.Equal(buf, want) {
+ t.Errorf("bytes read differ from what was expected: want = %v, got = %v", want, buf)
+ }
+}
+
+func allocateAndVerify(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, length uint64) {
+ if err := fdLisa.Allocate(ctx, 0, off, length); err != nil {
+ t.Fatalf("fallocate failed: %v", err)
+ }
+
+ var stat linux.Statx
+ statTo(ctx, t, fdLisa, &stat)
+ if want := off + length; stat.Size != want {
+ t.Errorf("incorrect file size after allocate: want %d, got %d", want, stat.Size)
+ }
+}
+
+func cmpStatx(t *testing.T, want, got linux.Statx) {
+ if got.Mask&unix.STATX_MODE != 0 && want.Mask&unix.STATX_MODE != 0 {
+ if got.Mode != want.Mode {
+ t.Errorf("mode differs: want %d, got %d", want.Mode, got.Mode)
+ }
+ }
+ if got.Mask&unix.STATX_INO != 0 && want.Mask&unix.STATX_INO != 0 {
+ if got.Ino != want.Ino {
+ t.Errorf("inode number differs: want %d, got %d", want.Ino, got.Ino)
+ }
+ }
+ if got.Mask&unix.STATX_NLINK != 0 && want.Mask&unix.STATX_NLINK != 0 {
+ if got.Nlink != want.Nlink {
+ t.Errorf("nlink differs: want %d, got %d", want.Nlink, got.Nlink)
+ }
+ }
+ if got.Mask&unix.STATX_UID != 0 && want.Mask&unix.STATX_UID != 0 {
+ if got.UID != want.UID {
+ t.Errorf("UID differs: want %d, got %d", want.UID, got.UID)
+ }
+ }
+ if got.Mask&unix.STATX_GID != 0 && want.Mask&unix.STATX_GID != 0 {
+ if got.GID != want.GID {
+ t.Errorf("GID differs: want %d, got %d", want.GID, got.GID)
+ }
+ }
+ if got.Mask&unix.STATX_SIZE != 0 && want.Mask&unix.STATX_SIZE != 0 {
+ if got.Size != want.Size {
+ t.Errorf("size differs: want %d, got %d", want.Size, got.Size)
+ }
+ }
+ if got.Mask&unix.STATX_BLOCKS != 0 && want.Mask&unix.STATX_BLOCKS != 0 {
+ if got.Blocks != want.Blocks {
+ t.Errorf("blocks differs: want %d, got %d", want.Blocks, got.Blocks)
+ }
+ }
+ if got.Mask&unix.STATX_ATIME != 0 && want.Mask&unix.STATX_ATIME != 0 {
+ if got.Atime != want.Atime {
+ t.Errorf("atime differs: want %d, got %d", want.Atime, got.Atime)
+ }
+ }
+ if got.Mask&unix.STATX_MTIME != 0 && want.Mask&unix.STATX_MTIME != 0 {
+ if got.Mtime != want.Mtime {
+ t.Errorf("mtime differs: want %d, got %d", want.Mtime, got.Mtime)
+ }
+ }
+ if got.Mask&unix.STATX_CTIME != 0 && want.Mask&unix.STATX_CTIME != 0 {
+ if got.Ctime != want.Ctime {
+ t.Errorf("ctime differs: want %d, got %d", want.Ctime, got.Ctime)
+ }
+ }
+}
+
+func hasCapability(c capability.Cap) bool {
+ caps, err := capability.NewPid2(os.Getpid())
+ if err != nil {
+ return false
+ }
+ if err := caps.Load(); err != nil {
+ return false
+ }
+ return caps.Get(capability.EFFECTIVE, c)
+}
+
+func testStat(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ var rootStat linux.Statx
+ if err := root.StatTo(ctx, &rootStat); err != nil {
+ t.Errorf("stat on the root dir failed: %v", err)
+ }
+
+ if ftype := rootStat.Mode & unix.S_IFMT; ftype != unix.S_IFDIR {
+ t.Errorf("root inode is not a directory, file type = %d", ftype)
+ }
+}
+
+func testRegularFileIO(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ // Test Read/Write RPCs with 2MB of data to test IO in chunks.
+ data := make([]byte, 1<<21)
+ rand.Read(data)
+ if err := writeFD(ctx, t, fd, 0, data); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ readFDAndCmp(ctx, t, fd, 0, data)
+ readFDAndCmp(ctx, t, fd, 50, data[50:])
+
+ // Make sure the host FD is configured properly.
+ hostReadData := make([]byte, len(data))
+ if n, err := unix.Pread(hostFD, hostReadData, 0); err != nil {
+ t.Errorf("host read failed: %v", err)
+ } else if n != len(hostReadData) {
+ t.Errorf("partial read: buf size = %d, read = %d", len(hostReadData), n)
+ } else if !bytes.Equal(hostReadData, data) {
+ t.Errorf("bytes read differ from what was expected: want = %v, got = %v", data, hostReadData)
+ }
+
+ // Test syncing the writable FD.
+ if err := fd.Sync(ctx); err != nil {
+ t.Errorf("syncing the FD failed: %v", err)
+ }
+}
+
+func testRegularFileOpen(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ // Open a readonly FD and try writing to it to get an EBADF.
+ roFile, roHostFD := openFile(ctx, t, controlFile, unix.O_RDONLY, true /* isReg */)
+ defer closeFD(ctx, t, roFile)
+ defer unix.Close(roHostFD)
+ if err := writeFD(ctx, t, roFile, 0, []byte{1, 2, 3}); err != unix.EBADF {
+ t.Errorf("writing to read only FD should generate EBADF, but got %v", err)
+ }
+}
+
+func testSetStat(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ now := time.Now()
+ wantStat := linux.Statx{
+ Mask: unix.STATX_MODE | unix.STATX_ATIME | unix.STATX_MTIME | unix.STATX_SIZE,
+ Mode: 0760,
+ UID: uint32(unix.Getuid()),
+ GID: uint32(unix.Getgid()),
+ Size: 50,
+ Atime: linux.NsecToStatxTimestamp(now.UnixNano()),
+ Mtime: linux.NsecToStatxTimestamp(now.UnixNano()),
+ }
+ if tester.SetUserGroupIDSupported() {
+ wantStat.Mask |= unix.STATX_UID | unix.STATX_GID
+ }
+ failureMask, failureErr, err := controlFile.SetStat(ctx, &wantStat)
+ if err != nil {
+ t.Fatalf("setstat failed: %v", err)
+ }
+ if failureMask != 0 {
+ t.Fatalf("some setstat operations failed: failureMask = %#b, failureErr = %v", failureMask, failureErr)
+ }
+
+ // Verify that attributes were updated.
+ var gotStat linux.Statx
+ statTo(ctx, t, controlFile, &gotStat)
+ if gotStat.Mode&07777 != wantStat.Mode ||
+ gotStat.Size != wantStat.Size ||
+ gotStat.Atime.ToNsec() != wantStat.Atime.ToNsec() ||
+ gotStat.Mtime.ToNsec() != wantStat.Mtime.ToNsec() ||
+ (tester.SetUserGroupIDSupported() && (uint32(gotStat.UID) != wantStat.UID || uint32(gotStat.GID) != wantStat.GID)) {
+ t.Errorf("setStat did not update file correctly: setStat = %+v, stat = %+v", wantStat, gotStat)
+ }
+}
+
+func testAllocate(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ allocateAndVerify(ctx, t, fd, 0, 40)
+ allocateAndVerify(ctx, t, fd, 20, 100)
+}
+
+func testStatFS(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ var statFS lisafs.StatFS
+ if err := root.StatFSTo(ctx, &statFS); err != nil {
+ t.Errorf("statfs failed: %v", err)
+ }
+}
+
+func testUnlink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ unlinkFile(ctx, t, root, name, false /* isDir */)
+ if inodes := walk(ctx, t, root, []string{name}); len(inodes) > 0 {
+ t.Errorf("deleted file should not be generating inodes on walk: %+v", inodes)
+ }
+}
+
+func testSymlink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ target := "/tmp/some/path"
+ name := "symlinkFile"
+ link, linkStat := symlink(ctx, t, root, name, target)
+ defer closeFD(ctx, t, link)
+
+ if linkStat.Mode&unix.S_IFMT != unix.S_IFLNK {
+ t.Errorf("stat return from symlink RPC indicates that the inode is not a symlink: mode = %d", linkStat.Mode)
+ }
+
+ if gotTarget, err := link.ReadLinkAt(ctx); err != nil {
+ t.Fatalf("readlink failed: %v", err)
+ } else if gotTarget != target {
+ t.Errorf("readlink return incorrect target: expected %q, got %q", target, gotTarget)
+ }
+}
+
+func testHardLink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ if !tester.LinkSupported() {
+ t.Skipf("server does not support LinkAt RPC")
+ }
+ if !hasCapability(capability.CAP_DAC_READ_SEARCH) {
+ t.Skipf("TestHardLink requires CAP_DAC_READ_SEARCH, running as %d", unix.Getuid())
+ }
+ name := "tempFile"
+ controlFile, fileIno, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ link, linkStat := link(ctx, t, root, name, controlFile)
+ defer closeFD(ctx, t, link)
+
+ if linkStat.Ino != fileIno.Ino {
+ t.Errorf("hard linked files have different inode numbers: %d %d", linkStat.Ino, fileIno.Ino)
+ }
+ if linkStat.DevMinor != fileIno.DevMinor {
+ t.Errorf("hard linked files have different minor device numbers: %d %d", linkStat.DevMinor, fileIno.DevMinor)
+ }
+ if linkStat.DevMajor != fileIno.DevMajor {
+ t.Errorf("hard linked files have different major device numbers: %d %d", linkStat.DevMajor, fileIno.DevMajor)
+ }
+}
+
+func testWalk(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ // Create 10 nested directories.
+ n := 10
+ curDir := root
+
+ dirNames := make([]string, 0, n)
+ for i := 0; i < n; i++ {
+ name := fmt.Sprintf("tmpdir-%d", i)
+ childDir, _ := mkdir(ctx, t, curDir, name)
+ defer closeFD(ctx, t, childDir)
+ defer unlinkFile(ctx, t, curDir, name, true /* isDir */)
+
+ curDir = childDir
+ dirNames = append(dirNames, name)
+ }
+
+ // Walk all these directories. Add some junk at the end which should not be
+ // walked on.
+ dirNames = append(dirNames, []string{"a", "b", "c"}...)
+ inodes := walk(ctx, t, root, dirNames)
+ if len(inodes) != n {
+ t.Errorf("walk returned the incorrect number of inodes: wanted %d, got %d", n, len(inodes))
+ }
+
+ // Close all control FDs and collect stat results for all dirs including
+ // the root directory.
+ dirStats := make([]linux.Statx, 0, n+1)
+ var stat linux.Statx
+ statTo(ctx, t, root, &stat)
+ dirStats = append(dirStats, stat)
+ for _, inode := range inodes {
+ dirStats = append(dirStats, inode.Stat)
+ closeFD(ctx, t, root.Client().NewFD(inode.ControlFD))
+ }
+
+ // Test WalkStat which additionally returns Statx for root because the first
+ // path component is "".
+ dirNames = append([]string{""}, dirNames...)
+ gotStats := walkStat(ctx, t, root, dirNames)
+ if len(gotStats) != len(dirStats) {
+ t.Errorf("walkStat returned the incorrect number of statx: wanted %d, got %d", len(dirStats), len(gotStats))
+ } else {
+ for i := range gotStats {
+ cmpStatx(t, dirStats[i], gotStats[i])
+ }
+ }
+}
+
+func testRename(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ tempFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, tempFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ tempDir, _ := mkdir(ctx, t, root, "tempDir")
+ defer closeFD(ctx, t, tempDir)
+
+ // Move tempFile into tempDir.
+ if err := tempFile.RenameTo(ctx, tempDir.ID(), "movedFile"); err != nil {
+ t.Fatalf("rename failed: %v", err)
+ }
+
+ inodes := walkStat(ctx, t, root, []string{"tempDir", "movedFile"})
+ if len(inodes) != 2 {
+ t.Errorf("expected 2 files on walk but only found %d", len(inodes))
+ }
+}
+
+func testMknod(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "namedPipe"
+ pipeFile, pipeStat := mknod(ctx, t, root, name)
+ defer closeFD(ctx, t, pipeFile)
+
+ var stat linux.Statx
+ statTo(ctx, t, pipeFile, &stat)
+
+ if stat.Mode != pipeStat.Mode {
+ t.Errorf("mknod mode is incorrect: want %d, got %d", pipeStat.Mode, stat.Mode)
+ }
+ if stat.UID != pipeStat.UID {
+ t.Errorf("mknod UID is incorrect: want %d, got %d", pipeStat.UID, stat.UID)
+ }
+ if stat.GID != pipeStat.GID {
+ t.Errorf("mknod GID is incorrect: want %d, got %d", pipeStat.GID, stat.GID)
+ }
+}
+
+func testGetdents(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ tempDir, _ := mkdir(ctx, t, root, "tempDir")
+ defer closeFD(ctx, t, tempDir)
+ defer unlinkFile(ctx, t, root, "tempDir", true /* isDir */)
+
+ // Create 10 files in tempDir.
+ n := 10
+ fileStats := make(map[string]linux.Statx)
+ for i := 0; i < n; i++ {
+ name := fmt.Sprintf("file-%d", i)
+ newFile, fileStat := mknod(ctx, t, tempDir, name)
+ defer closeFD(ctx, t, newFile)
+ defer unlinkFile(ctx, t, tempDir, name, false /* isDir */)
+
+ fileStats[name] = fileStat
+ }
+
+ // Use opened directory FD for getdents.
+ openDirFile, _ := openFile(ctx, t, tempDir, unix.O_RDONLY, false /* isReg */)
+ defer closeFD(ctx, t, openDirFile)
+
+ dirents := make([]lisafs.Dirent64, 0, n)
+ for i := 0; i < n+2; i++ {
+ gotDirents, err := openDirFile.Getdents64(ctx, 40)
+ if err != nil {
+ t.Fatalf("getdents failed: %v", err)
+ }
+ if len(gotDirents) == 0 {
+ break
+ }
+ for _, dirent := range gotDirents {
+ if dirent.Name != "." && dirent.Name != ".." {
+ dirents = append(dirents, dirent)
+ }
+ }
+ }
+
+ if len(dirents) != n {
+ t.Errorf("got incorrect number of dirents: wanted %d, got %d", n, len(dirents))
+ }
+ for _, dirent := range dirents {
+ stat, ok := fileStats[string(dirent.Name)]
+ if !ok {
+ t.Errorf("received a dirent that was not created: %+v", dirent)
+ continue
+ }
+
+ if dirent.Type != unix.DT_REG {
+ t.Errorf("dirent type of %s is incorrect: %d", dirent.Name, dirent.Type)
+ }
+ if uint64(dirent.Ino) != stat.Ino {
+ t.Errorf("dirent ino of %s is incorrect: want %d, got %d", dirent.Name, stat.Ino, dirent.Ino)
+ }
+ if uint32(dirent.DevMinor) != stat.DevMinor {
+ t.Errorf("dirent dev minor of %s is incorrect: want %d, got %d", dirent.Name, stat.DevMinor, dirent.DevMinor)
+ }
+ if uint32(dirent.DevMajor) != stat.DevMajor {
+ t.Errorf("dirent dev major of %s is incorrect: want %d, got %d", dirent.Name, stat.DevMajor, dirent.DevMajor)
+ }
+ }
+}
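
The getdents loop above mirrors how a real client drains a directory: issue Getdents64 until the server returns an empty batch. A minimal sketch of that pattern, assuming a lisafs.ClientFD for an already-opened directory and the Getdents64 signature used above (the count of 40 is only a hint for how many entries the server returns per RPC):

    func readAllDirents(ctx context.Context, dirFD lisafs.ClientFD) ([]lisafs.Dirent64, error) {
        var all []lisafs.Dirent64
        for {
            batch, err := dirFD.Getdents64(ctx, 40)
            if err != nil {
                return all, err
            }
            if len(batch) == 0 {
                // The server has no more entries.
                return all, nil
            }
            all = append(all, batch...)
        }
    }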
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 0b961d3d9..6358ad8e9 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -384,6 +384,14 @@ func verifyMetadata(params *VerifyParams, layout *Layout) error {
return descriptor.verify(params.Expected, params.HashAlgorithms)
}
+// cachedHashes stores verified hashes from a previous hash step.
+type cachedHashes struct {
+ // offset is the offset of the cached hash at each level.
+ offset []int64
+ // hash is the cached verified hash for each level from previous hash steps.
+ hash [][]byte
+}
+
// Verify verifies the content read from data with offset. The content is
// verified against tree. If content spans across multiple blocks, each block is
// verified. Verification fails if the hash of the data does not match the tree
@@ -409,29 +417,32 @@ func Verify(params *VerifyParams) (int64, error) {
firstDataBlock := params.ReadOffset / layout.blockSize
lastDataBlock := (params.ReadOffset + params.ReadSize - 1) / layout.blockSize
- buf := make([]byte, layout.blockSize)
- var readErr error
- total := int64(0)
+ size := (lastDataBlock - firstDataBlock + 1) * layout.blockSize
+ retBuf := make([]byte, size)
+ n, err := params.File.ReadAt(retBuf, firstDataBlock*layout.blockSize)
+ if err != nil && err != io.EOF {
+ return 0, err
+ }
+ total := int64(n)
+ bytesRead := int64(0)
+
+ // Only cache hash results if reading more than a block.
+ var ch *cachedHashes
+ if lastDataBlock > firstDataBlock {
+ ch = &cachedHashes{
+ offset: make([]int64, layout.numLevels()),
+ hash: make([][]byte, layout.numLevels()),
+ }
+ }
for i := firstDataBlock; i <= lastDataBlock; i++ {
+ // Reached the end of file during verification.
+ if total <= 0 {
+ return bytesRead, io.EOF
+ }
// Read a block that includes all or part of target range in
// input data.
- bytesRead, err := params.File.ReadAt(buf, i*layout.blockSize)
- readErr = err
- // If at the end of input data and all previous blocks are
- // verified, return the verified input data and EOF.
- if readErr == io.EOF && bytesRead == 0 {
- break
- }
- if readErr != nil && readErr != io.EOF {
- return 0, fmt.Errorf("read from data failed: %w", err)
- }
- // If this is the end of file, zero the remaining bytes in buf,
- // otherwise they are still from the previous block.
- if bytesRead < len(buf) {
- for j := bytesRead; j < len(buf); j++ {
- buf[j] = 0
- }
- }
+ buf := retBuf[(i-firstDataBlock)*layout.blockSize : (i-firstDataBlock+1)*layout.blockSize]
+
descriptor := VerityDescriptor{
Name: params.Name,
FileSize: params.Size,
@@ -441,8 +452,8 @@ func Verify(params *VerifyParams) (int64, error) {
SymlinkTarget: params.SymlinkTarget,
Children: params.Children,
}
- if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil {
- return 0, err
+ if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected, ch); err != nil {
+ return bytesRead, err
}
// startOff is the beginning of the read range within the
@@ -459,22 +470,24 @@ func Verify(params *VerifyParams) (int64, error) {
if i == lastDataBlock {
endOff = (params.ReadOffset+params.ReadSize-1)%layout.blockSize + 1
}
+
// If the provided size exceeds the end of input data, we should
// only copy the parts of buf that are part of the input data.
- if startOff > int64(bytesRead) {
- startOff = int64(bytesRead)
+ if startOff > total {
+ startOff = total
}
- if endOff > int64(bytesRead) {
- endOff = int64(bytesRead)
+ if endOff > total {
+ endOff = total
}
+
n, err := params.Out.Write(buf[startOff:endOff])
if err != nil {
- return total, err
+ return bytesRead, err
}
- total += int64(n)
-
+ bytesRead += int64(n)
+ total -= endOff
}
- return total, readErr
+ return bytesRead, nil
}
// verifyBlock verifies a block against tree. index is the number of block in
@@ -482,7 +495,7 @@ func Verify(params *VerifyParams) (int64, error) {
// fails if the calculated hash from block is different from any level of
// hashes stored in tree. And the final root hash is compared with
// expected.
-func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte) error {
+func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte, ch *cachedHashes) error {
if len(dataBlock) != int(layout.blockSize) {
return fmt.Errorf("incorrect block size")
}
@@ -491,6 +504,12 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout,
treeBlock := make([]byte, layout.blockSize)
var digest []byte
for level := 0; level < layout.numLevels(); level++ {
+ // No need to verify remaining levels if the current block has
+ // been verified in a previous call and cached.
+ if ch != nil && ch.offset[level] == layout.digestOffset(level, blockIndex) && ch.hash[level] != nil {
+ break
+ }
+
// Calculate hash.
if level == 0 {
h, err := hashData(dataBlock, hashAlgorithms)
@@ -521,11 +540,19 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout,
if !bytes.Equal(digest, expectedDigest) {
return fmt.Errorf("verification failed")
}
+ if ch != nil {
+ ch.offset[level] = layout.digestOffset(level, blockIndex)
+ ch.hash[level] = expectedDigest
+ }
blockIndex = blockIndex / layout.hashesPerBlock()
}
// Verification for the tree succeeded. Now hash the descriptor with
// the root hash and compare it with expected.
- descriptor.RootHash = digest
+ if ch != nil {
+ descriptor.RootHash = ch.hash[layout.rootLevel()]
+ } else {
+ descriptor.RootHash = digest
+ }
return descriptor.verify(expected, hashAlgorithms)
}
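
The caching pays off because adjacent data blocks almost always share their upper-level hash blocks: the digest for a block at one level lives in hash block blockIndex / hashesPerBlock() at the next level up, so a sequential read keeps landing in the same parent until it crosses a boundary. A small illustration of that arithmetic (the value 128 assumes a 4096-byte block with 32-byte SHA-256 digests; a sketch, not this package's API):

    // parent returns the index of the hash block holding the digest for
    // blockIndex at the next level up.
    parent := func(blockIndex, hashesPerBlock int64) int64 {
        return blockIndex / hashesPerBlock
    }
    fmt.Println(parent(6, 128) == parent(7, 128))     // true: level-1 hash block is shared
    fmt.Println(parent(127, 128) == parent(128, 128)) // false: crossed a parent boundary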
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 764f1f970..d618da820 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -115,7 +115,7 @@ type Client struct {
// channels is the set of all initialized channels.
channels []*channel
- // availableChannels is a FIFO of inactive channels.
+ // availableChannels is a LIFO of inactive channels.
availableChannels []*channel
// -- below corresponds to sendRecvLegacy --
@@ -528,7 +528,7 @@ func (c *Client) sendRecvChannel(t message, r message) error {
}
// Send the request and receive the server's response.
- rsz, err := ch.send(t)
+ rsz, err := ch.send(t, false /* isServer */)
if err != nil {
// See above.
c.channelsMu.Lock()
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
index 97e0231d6..8d6af2d6b 100644
--- a/pkg/p9/file.go
+++ b/pkg/p9/file.go
@@ -325,6 +325,12 @@ func (*DisallowServerCalls) Renamed(File, string) {
func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) {
stats := make([]FullStat, 0, len(names))
parent := start
+ closeParent := func() {
+ if parent != start {
+ _ = parent.Close()
+ }
+ }
+ defer closeParent()
mask := AttrMaskAll()
for i, name := range names {
if len(name) == 0 && i == 0 {
@@ -340,15 +346,14 @@ func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) {
continue
}
qids, child, valid, attr, err := parent.WalkGetAttr([]string{name})
- if parent != start {
- _ = parent.Close()
- }
if err != nil {
if errors.Is(err, unix.ENOENT) {
return stats, nil
}
return nil, err
}
+ closeParent()
+ parent = child
stats = append(stats, FullStat{
QID: qids[0],
Valid: valid,
@@ -357,13 +362,8 @@ func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) {
if attr.Mode.FileType() != ModeDirectory {
// Doesn't need to continue if entry is not a dir. Including symlinks
// that cannot be followed.
- _ = child.Close()
break
}
- parent = child
- }
- if parent != start {
- _ = parent.Close()
}
return stats, nil
}
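
The refactor above fixes the leak paths by funneling every exit through one cleanup: a closure closes the current intermediate directory unless it is the caller-owned start, and a defer guarantees it also runs on error returns. The same pattern in isolation (walkOne and the handle type are illustrative, not this package's API):

    cur := start
    closeCur := func() {
        if cur != start {
            _ = cur.Close()
        }
    }
    defer closeCur()
    for _, name := range names {
        next, err := walkOne(cur, name) // hypothetical one-step walk
        if err != nil {
            return err // the deferred closeCur still releases cur
        }
        closeCur() // done with the old cursor...
        cur = next // ...and the defer now guards the new one
    }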
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 161b451cc..a8f8a9d03 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -45,6 +45,8 @@ func ExtractErrno(err error) unix.Errno {
// Attempt to unwrap.
switch e := err.(type) {
+ case *errors.Error:
+ return unix.Errno(e.Errno())
case unix.Errno:
return e
case *os.PathError:
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
index 802254a90..69a9f2537 100644
--- a/pkg/p9/transport_flipcall.go
+++ b/pkg/p9/transport_flipcall.go
@@ -85,7 +85,7 @@ func (ch *channel) service(cs *connState) error {
}
r := cs.handle(m)
msgRegistry.put(m)
- rsz, err = ch.send(r)
+ rsz, err = ch.send(r, true /* isServer */)
if err != nil {
return err
}
@@ -122,7 +122,7 @@ func (ch *channel) Close() error {
//
// The return value is the size of the received response. Note that in the
// server case, this is the size of the next request.
-func (ch *channel) send(m message) (uint32, error) {
+func (ch *channel) send(m message, isServer bool) (uint32, error) {
if log.IsLogging(log.Debug) {
log.Debugf("send [channel @%p] %s", ch, m.String())
}
@@ -162,7 +162,11 @@ func (ch *channel) send(m message) (uint32, error) {
}
// Perform the one-shot communication.
- return ch.data.SendRecv(ssz)
+ if isServer {
+ return ch.data.SendRecv(ssz)
+ }
+ // RPCs are expected to return quickly rather than block.
+ return ch.data.SendRecvFast(ssz)
}
// recv decodes a message that exists on the channel.
diff --git a/pkg/ring0/defs.go b/pkg/ring0/defs.go
index b6e2012e8..38ce9be1e 100644
--- a/pkg/ring0/defs.go
+++ b/pkg/ring0/defs.go
@@ -77,6 +77,9 @@ type CPU struct {
// calls and exceptions via the Registers function.
registers arch.Registers
+ // floatingPointState holds floating point state.
+ floatingPointState fpu.State
+
// hooks are kernel hooks.
hooks Hooks
}
@@ -90,6 +93,15 @@ func (c *CPU) Registers() *arch.Registers {
return &c.registers
}
+// FloatingPointState returns the kernel floating point state.
+//
+// This is explicitly safe to call during KernelException and KernelSyscall.
+//
+//go:nosplit
+func (c *CPU) FloatingPointState() *fpu.State {
+ return &c.floatingPointState
+}
+
// SwitchOpts are passed to the Switch function.
type SwitchOpts struct {
// Registers are the user register state.
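
With the entry paths now saving this state themselves, a Hooks implementation can inspect the kernel floating point state directly. A minimal sketch, assuming the Hooks interface exposes KernelSyscall and KernelException as in this package, and that the implementation keeps a pointer to its CPU:

    type debugHooks struct {
        cpu *ring0.CPU
    }

    func (h *debugHooks) KernelSyscall() {
        fp := h.cpu.FloatingPointState() // explicitly safe here, per the comment above
        _ = fp                           // e.g., snapshot the saved XSAVE area
    }

    func (h *debugHooks) KernelException(v ring0.Vector) {
        _ = h.cpu.FloatingPointState() // likewise safe during exceptions
    }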
diff --git a/pkg/ring0/defs_amd64.go b/pkg/ring0/defs_amd64.go
index 24f6e4cde..81e90dbf7 100644
--- a/pkg/ring0/defs_amd64.go
+++ b/pkg/ring0/defs_amd64.go
@@ -116,6 +116,11 @@ type CPUArchState struct {
errorType uintptr
*kernelEntry
+
+ // Copies of global variables, stored in CPU so that they can be used by
+ // syscall and exception handlers (in the upper address space).
+ hasXSAVE bool
+ hasXSAVEOPT bool
}
// ErrorCode returns the last error code.
diff --git a/pkg/ring0/entry_amd64.go b/pkg/ring0/entry_amd64.go
index afd646b0b..13ad4e4df 100644
--- a/pkg/ring0/entry_amd64.go
+++ b/pkg/ring0/entry_amd64.go
@@ -39,11 +39,6 @@ func sysenter()
// assembly to get the ABI0 (i.e., primary) address.
func addrOfSysenter() uintptr
-// swapgs swaps the current GS value.
-//
-// This must be called prior to sysret/iret.
-func swapgs()
-
// jumpToKernel jumps to the kernel version of the current RIP.
func jumpToKernel()
diff --git a/pkg/ring0/entry_amd64.s b/pkg/ring0/entry_amd64.s
index 520bd9f57..d2913f190 100644
--- a/pkg/ring0/entry_amd64.s
+++ b/pkg/ring0/entry_amd64.s
@@ -142,8 +142,103 @@ TEXT ·jumpToUser(SB),NOSPLIT,$0
MOVQ AX, 0(SP)
RET
+// See kernel_amd64.go.
+//
+// The 16-byte frame size is for the saved values of MXCSR and the x87 control
+// word.
+TEXT ·doSwitchToUser(SB),NOSPLIT,$16-48
+ // We are passed pointers to heap objects, but do not store them in our
+ // local frame.
+ NO_LOCAL_POINTERS
+
+ // MXCSR and the x87 control word are the only floating point state
+ // that is callee-save, and thus the only state we must save here.
+ STMXCSR mxcsr-0(SP)
+ FSTCW cw-8(SP)
+
+ // Restore application floating point state.
+ MOVQ cpu+0(FP), SI
+ MOVQ fpState+16(FP), DI
+ MOVB ·hasXSAVE(SB), BX
+ TESTB BX, BX
+ JZ no_xrstor
+ // Use xrstor to restore all available fp state. For now, we restore
+ // everything unconditionally by setting the implicit operand edx:eax
+ // (the "requested feature bitmap") to all 1's.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x2f // XRSTOR64 0(DI)
+ JMP fprestore_done
+no_xrstor:
+ // Fall back to fxrstor if xsave is not available.
+ FXRSTOR64 0(DI)
+fprestore_done:
+
+ // Set application GS.
+ MOVQ regs+8(FP), R8
+ SWAP_GS()
+ MOVQ PTRACE_GS_BASE(R8), AX
+ PUSHQ AX
+ CALL ·writeGS(SB)
+ POPQ AX
+
+ // Call sysret() or iret().
+ MOVQ userCR3+24(FP), CX
+ MOVQ needIRET+32(FP), R9
+ ADDQ $-32, SP
+ MOVQ SI, 0(SP) // cpu
+ MOVQ R8, 8(SP) // regs
+ MOVQ CX, 16(SP) // userCR3
+ TESTQ R9, R9
+ JNZ do_iret
+ CALL ·sysret(SB)
+ JMP done_sysret_or_iret
+do_iret:
+ CALL ·iret(SB)
+done_sysret_or_iret:
+ MOVQ 24(SP), AX // vector
+ ADDQ $32, SP
+ MOVQ AX, vector+40(FP)
+
+ // Save application floating point state.
+ MOVQ fpState+16(FP), DI
+ MOVB ·hasXSAVE(SB), BX
+ MOVB ·hasXSAVEOPT(SB), CX
+ TESTB BX, BX
+ JZ no_xsave
+ // Use xsave/xsaveopt to save all extended state.
+ // We save everything unconditionally by setting RFBM to all 1's.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ TESTB CX, CX
+ JZ no_xsaveopt
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI)
+ JMP fpsave_done
+no_xsaveopt:
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI)
+ JMP fpsave_done
+no_xsave:
+ FXSAVE64 0(DI)
+fpsave_done:
+
+ // Restore MXCSR and the x87 control word after one of the two floating
+ // point save cases above, to ensure the application versions are saved
+ // before being clobbered here.
+ LDMXCSR mxcsr-0(SP)
+
+ // FLDCW is a "waiting" x87 instruction, meaning it checks for pending
+ // unmasked exceptions before executing. Thus if userspace has unmasked
+ // an exception and has one pending, it can be raised by FLDCW even
+ // though the new control word will mask exceptions. To prevent this,
+ // we must first clear pending exceptions (which will be restored by
+ // XRSTOR, et al).
+ BYTE $0xDB; BYTE $0xE2; // FNCLEX
+ FLDCW cw-8(SP)
+
+ RET
+
// See entry_amd64.go.
-TEXT ·sysret(SB),NOSPLIT,$0-24
+TEXT ·sysret(SB),NOSPLIT,$0-32
// Set application FS. We can't do this in Go because Go code needs FS.
MOVQ regs+8(FP), AX
MOVQ PTRACE_FS_BASE(AX), AX
@@ -182,9 +277,11 @@ TEXT ·sysret(SB),NOSPLIT,$0-24
POPQ AX // Restore AX.
POPQ SP // Restore SP.
SYSRET64()
+ // sysenter or exception will write our return value and return to our
+ // caller.
// See entry_amd64.go.
-TEXT ·iret(SB),NOSPLIT,$0-24
+TEXT ·iret(SB),NOSPLIT,$0-32
// Set application FS. We can't do this in Go because Go code needs FS.
MOVQ regs+8(FP), AX
MOVQ PTRACE_FS_BASE(AX), AX
@@ -220,6 +317,8 @@ TEXT ·iret(SB),NOSPLIT,$0-24
WRITE_CR3() // Switch to userCR3.
POPQ AX // Restore AX.
IRET()
+ // sysenter or exception will write our return value and return to our
+ // caller.
// See entry_amd64.go.
TEXT ·resume(SB),NOSPLIT,$0
@@ -324,11 +423,39 @@ kernel:
MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code.
MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
+ // Save floating point state. CPU.floatingPointState is a slice, so the
+ // first word of CPU.floatingPointState is a pointer to the destination
+ // array.
+ MOVQ CPU_FPU_STATE(AX), DI
+ MOVB CPU_HAS_XSAVE(AX), BX
+ MOVB CPU_HAS_XSAVEOPT(AX), CX
+ TESTB BX, BX
+ JZ no_xsave
+ // Use xsave/xsaveopt to save all extended state.
+ // We save everything unconditionally by setting RFBM to all 1's.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ TESTB CX, CX
+ JZ no_xsaveopt
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI)
+ JMP fpsave_done
+no_xsaveopt:
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI)
+ JMP fpsave_done
+no_xsave:
+ FXSAVE64 0(DI)
+fpsave_done:
+
// Call the syscall trampoline.
LOAD_KERNEL_STACK(GS)
- PUSHQ AX // First argument (vCPU).
- CALL ·kernelSyscall(SB) // Call the trampoline.
- POPQ AX // Pop vCPU.
+ MOVQ ENTRY_CPU_SELF(GS), AX // AX contains the vCPU.
+ PUSHQ AX // First argument (vCPU).
+ CALL ·kernelSyscall(SB) // Call the trampoline.
+ POPQ AX // Pop vCPU.
+
+ // We only trigger a bluepill entry in the bluepill function, and are
+ // therefore guaranteed that there is no floating point state to be
+ // loaded on resuming from halt.
JMP ·resume(SB)
ADDR_OF_FUNC(·addrOfSysenter(SB), ·sysenter(SB));
@@ -416,15 +543,43 @@ kernel:
MOVQ 8(SP), BX // Load the error code.
MOVQ BX, CPU_ERROR_CODE(AX) // Copy out to the CPU.
MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
- MOVQ 0(SP), BX // BX contains the vector.
+
+ // Save floating point state. CPU.floatingPointState is a slice, so the
+ // first word of CPU.floatingPointState is a pointer to the destination
+ // array.
+ MOVQ CPU_FPU_STATE(AX), DI
+ MOVB CPU_HAS_XSAVE(AX), BX
+ MOVB CPU_HAS_XSAVEOPT(AX), CX
+ TESTB BX, BX
+ JZ no_xsave
+ // Use xsave/xsaveopt to save all extended state.
+ // We save everything unconditionally by setting RFBM to all 1's.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ TESTB CX, CX
+ JZ no_xsaveopt
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI)
+ JMP fpsave_done
+no_xsaveopt:
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI)
+ JMP fpsave_done
+no_xsave:
+ FXSAVE64 0(DI)
+fpsave_done:
// Call the exception trampoline.
+ MOVQ 0(SP), BX // BX contains the vector.
LOAD_KERNEL_STACK(GS)
- PUSHQ BX // Second argument (vector).
- PUSHQ AX // First argument (vCPU).
- CALL ·kernelException(SB) // Call the trampoline.
- POPQ BX // Pop vector.
- POPQ AX // Pop vCPU.
+ MOVQ ENTRY_CPU_SELF(GS), AX // AX contains the vCPU.
+ PUSHQ BX // Second argument (vector).
+ PUSHQ AX // First argument (vCPU).
+ CALL ·kernelException(SB) // Call the trampoline.
+ POPQ BX // Pop vector.
+ POPQ AX // Pop vCPU.
+
+ // We only trigger a bluepill entry in the bluepill function, and are
+ // therefore guaranteed that there is no floating point state to be
+ // loaded on resuming from halt.
JMP ·resume(SB)
#define EXCEPTION_WITH_ERROR(value, symbol, addr) \
diff --git a/pkg/ring0/kernel.go b/pkg/ring0/kernel.go
index 292f9d0cc..e7dd84929 100644
--- a/pkg/ring0/kernel.go
+++ b/pkg/ring0/kernel.go
@@ -14,6 +14,10 @@
package ring0
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
+)
+
// Init initializes a new kernel.
//
//go:nosplit
@@ -80,6 +84,7 @@ func (c *CPU) Init(k *Kernel, cpuID int, hooks Hooks) {
c.self = c // Set self reference.
c.kernel = k // Set kernel reference.
c.init(cpuID) // Perform architectural init.
+ c.floatingPointState = fpu.NewState()
// Require hooks.
if hooks != nil {
diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go
index 4a4c0ae26..7e55011b5 100644
--- a/pkg/ring0/kernel_amd64.go
+++ b/pkg/ring0/kernel_amd64.go
@@ -143,6 +143,9 @@ func (c *CPU) init(cpuID int) {
// Set mandatory flags.
c.registers.Eflags = KernelFlagsSet
+
+ c.hasXSAVE = hasXSAVE
+ c.hasXSAVEOPT = hasXSAVEOPT
}
// StackTop returns the kernel's stack address.
@@ -248,19 +251,21 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Ss = uint64(Udata) // Ditto.
// Perform the switch.
- swapgs() // GS will be swapped on return.
- WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
- LoadFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy in floating point.
+ needIRET := uint64(0)
if switchOpts.FullRestore {
- vector = iret(c, regs, uintptr(userCR3))
- } else {
- vector = sysret(c, regs, uintptr(userCR3))
+ needIRET = 1
}
- SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point.
- RestoreKernelFPState() // escapes: no. Restore kernel MXCSR.
+ vector = doSwitchToUser(c, regs, switchOpts.FloatingPointState.BytePointer(), userCR3, needIRET) // escapes: no.
return
}
+func doSwitchToUser(
+ cpu *CPU, // +0(FP)
+ regs *arch.Registers, // +8(FP)
+ fpState *byte, // +16(FP)
+ userCR3 uint64, // +24(FP)
+ needIRET uint64) Vector // +32(FP), +40(FP)
+
var (
sentryXCR0 uintptr
sentryXCR0Once sync.Once
@@ -287,7 +292,7 @@ func initSentryXCR0() {
//go:nosplit
func startGo(c *CPU) {
// Save per-cpu.
- WriteGS(kernelAddr(c.kernelEntry))
+ writeGS(kernelAddr(c.kernelEntry))
//
// TODO(mpratt): Note that per the note above, this should be done
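
doSwitchToUser follows Go's standard pattern for assembly functions: a bodyless Go declaration fixes the signature and frame offsets (annotated +0(FP) through +40(FP) above), and a TEXT symbol in entry_amd64.s supplies the body. The pattern in miniature, with a hypothetical function that is not part of this package:

    // add.go
    func addOne(x uint64) uint64 // no body: implemented in assembly

    // add_amd64.s
    TEXT ·addOne(SB),NOSPLIT,$0-16
        MOVQ x+0(FP), AX
        INCQ AX
        MOVQ AX, ret+8(FP)
        RET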
diff --git a/pkg/ring0/lib_amd64.go b/pkg/ring0/lib_amd64.go
index 05c394ff5..c42a5b205 100644
--- a/pkg/ring0/lib_amd64.go
+++ b/pkg/ring0/lib_amd64.go
@@ -21,29 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
)
-// LoadFloatingPoint loads floating point state by the most efficient mechanism
-// available (set by Init).
-var LoadFloatingPoint func(*byte)
-
-// SaveFloatingPoint saves floating point state by the most efficient mechanism
-// available (set by Init).
-var SaveFloatingPoint func(*byte)
-
-// fxrstor uses fxrstor64 to load floating point state.
-func fxrstor(*byte)
-
-// xrstor uses xrstor to load floating point state.
-func xrstor(*byte)
-
-// fxsave uses fxsave64 to save floating point state.
-func fxsave(*byte)
-
-// xsave uses xsave to save floating point state.
-func xsave(*byte)
-
-// xsaveopt uses xsaveopt to save floating point state.
-func xsaveopt(*byte)
-
// writeFS sets the FS base address (selects one of wrfsbase or wrfsmsr).
func writeFS(addr uintptr)
@@ -53,8 +30,8 @@ func wrfsbase(addr uintptr)
// wrfsmsr writes to the FS_BASE MSR.
func wrfsmsr(addr uintptr)
-// WriteGS sets the GS address (set by init).
-var WriteGS func(addr uintptr)
+// writeGS sets the GS address (selects one of wrgsbase or wrgsmsr).
+func writeGS(addr uintptr)
// wrgsbase writes to the GS base address.
func wrgsbase(addr uintptr)
@@ -106,19 +83,4 @@ func Init(featureSet *cpuid.FeatureSet) {
hasXSAVE = featureSet.UseXsave()
hasFSGSBASE = featureSet.HasFeature(cpuid.X86FeatureFSGSBase)
validXCR0Mask = uintptr(featureSet.ValidXCR0Mask())
- if hasXSAVEOPT {
- SaveFloatingPoint = xsaveopt
- LoadFloatingPoint = xrstor
- } else if hasXSAVE {
- SaveFloatingPoint = xsave
- LoadFloatingPoint = xrstor
- } else {
- SaveFloatingPoint = fxsave
- LoadFloatingPoint = fxrstor
- }
- if hasFSGSBASE {
- WriteGS = wrgsbase
- } else {
- WriteGS = wrgsmsr
- }
}
diff --git a/pkg/ring0/lib_amd64.s b/pkg/ring0/lib_amd64.s
index 8ed98fc84..0f283aaae 100644
--- a/pkg/ring0/lib_amd64.s
+++ b/pkg/ring0/lib_amd64.s
@@ -128,6 +128,29 @@ TEXT ·wrfsmsr(SB),NOSPLIT,$0-8
BYTE $0x0f; BYTE $0x30;
RET
+// writeGS writes to the GS base.
+//
+// This is written in assembly because it must be callable from assembly (ABI0)
+// without an intermediate transition to ABIInternal.
+//
+// Preconditions: must be running in the lower address space, as it accesses
+// global data.
+TEXT ·writeGS(SB),NOSPLIT,$8-8
+ MOVQ addr+0(FP), AX
+
+ CMPB ·hasFSGSBASE(SB), $1
+ JNE msr
+
+ PUSHQ AX
+ CALL ·wrgsbase(SB)
+ POPQ AX
+ RET
+msr:
+ PUSHQ AX
+ CALL ·wrgsmsr(SB)
+ POPQ AX
+ RET
+
// wrgsbase writes to the GS base.
//
// The code corresponds to:
diff --git a/pkg/ring0/offsets_amd64.go b/pkg/ring0/offsets_amd64.go
index 75f6218b3..38fe27c35 100644
--- a/pkg/ring0/offsets_amd64.go
+++ b/pkg/ring0/offsets_amd64.go
@@ -35,6 +35,9 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_HAS_XSAVE 0x%02x\n", reflect.ValueOf(&c.hasXSAVE).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_HAS_XSAVEOPT 0x%02x\n", reflect.ValueOf(&c.hasXSAVEOPT).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_FPU_STATE 0x%02x\n", reflect.ValueOf(&c.floatingPointState).Pointer()-reflect.ValueOf(c).Pointer())
e := &kernelEntry{}
fmt.Fprintf(w, "\n// CPU entry offsets.\n")
diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go
index 9dac53c80..3f17fba49 100644
--- a/pkg/ring0/pagetables/pagetables.go
+++ b/pkg/ring0/pagetables/pagetables.go
@@ -322,12 +322,3 @@ func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarc
func (p *PageTables) MarkReadOnlyShared() {
p.readOnlyShared = true
}
-
-// PrefaultRootTable touches the root table page to be sure that its physical
-// pages are mapped.
-//
-//go:nosplit
-//go:noinline
-func (p *PageTables) PrefaultRootTable() PTE {
- return p.root[0]
-}
diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD
index db5787302..2a1602e2b 100644
--- a/pkg/safecopy/BUILD
+++ b/pkg/safecopy/BUILD
@@ -18,8 +18,9 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
- "//pkg/abi/linux",
- "//pkg/syserror",
+ "//pkg/errors",
+ "//pkg/errors/linuxerr",
+ "//pkg/sighandling",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go
index df63dd5f1..0dd0aea83 100644
--- a/pkg/safecopy/safecopy.go
+++ b/pkg/safecopy/safecopy.go
@@ -21,7 +21,9 @@ import (
"runtime"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/errors"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/sighandling"
)
// SegvError is returned when a safecopy function receives SIGSEGV.
@@ -82,7 +84,7 @@ var (
// when we get a SIGSEGV that is not interesting to us.
savedSigSegVHandler uintptr
- // same a above, but for SIGBUS signals.
+ // Same as above, but for SIGBUS signals.
savedSigBusHandler uintptr
)
@@ -131,18 +133,18 @@ func initializeAddresses() {
func init() {
initializeAddresses()
- if err := ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil {
+ if err := sighandling.ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err))
}
- if err := ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil {
+ if err := sighandling.ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err))
}
- syserror.AddErrorUnwrapper(func(e error) (unix.Errno, bool) {
+ linuxerr.AddErrorUnwrapper(func(e error) (*errors.Error, bool) {
switch e.(type) {
case SegvError, BusError, AlignmentError:
- return unix.EFAULT, true
+ return linuxerr.EFAULT, true
default:
- return 0, false
+ return nil, false
}
})
}
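
The unwrapper registered above is the extension point that lets fault types defined here surface as ordinary errno-style errors elsewhere in the sentry. Registering one for a custom fault type follows the same shape (myFaultError is hypothetical):

    linuxerr.AddErrorUnwrapper(func(e error) (*errors.Error, bool) {
        if _, ok := e.(myFaultError); ok {
            return linuxerr.EFAULT, true // translate the fault to EFAULT
        }
        return nil, false // not ours; let other unwrappers try
    })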
diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go
index 2365b2c0d..15f84abea 100644
--- a/pkg/safecopy/safecopy_unsafe.go
+++ b/pkg/safecopy/safecopy_unsafe.go
@@ -20,7 +20,6 @@ import (
"unsafe"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/abi/linux"
)
// maxRegisterSize is the maximum register size used in memcpy and memclr. It
@@ -332,39 +331,3 @@ func errorFromFaultSignal(addr uintptr, sig int32) error {
panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr))
}
}
-
-// ReplaceSignalHandler replaces the existing signal handler for the provided
-// signal with the one that handles faults in safecopy-protected functions.
-//
-// It stores the value of the previously set handler in previous.
-//
-// This function will be called on initialization in order to install safecopy
-// handlers for appropriate signals. These handlers will call the previous
-// handler however, and if this is function is being used externally then the
-// same courtesy is expected.
-func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error {
- var sa linux.SigAction
- const maskLen = 8
-
- // Get the existing signal handler information, and save the current
- // handler. Once we replace it, we will use this pointer to fall back to
- // it when we receive other signals.
- if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 {
- return e
- }
-
- // Fail if there isn't a previous handler.
- if sa.Handler == 0 {
- return fmt.Errorf("previous handler for signal %x isn't set", sig)
- }
-
- *previous = uintptr(sa.Handler)
-
- // Install our own handler.
- sa.Handler = uint64(handler)
- if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 {
- return e
- }
-
- return nil
-}
diff --git a/pkg/safemem/io.go b/pkg/safemem/io.go
index f039a5c34..9551ca853 100644
--- a/pkg/safemem/io.go
+++ b/pkg/safemem/io.go
@@ -207,58 +207,6 @@ func (r FromIOReader) readToBlock(dst Block, buf []byte) (int, []byte, error) {
return wbn, buf, rerr
}
-// FromIOReaderAt implements Reader for an io.ReaderAt. Does not repeatedly
-// invoke io.ReaderAt.ReadAt because ReadAt is more strict than Read. A partial
-// read indicates an error. This is not thread-safe.
-type FromIOReaderAt struct {
- ReaderAt io.ReaderAt
- Offset int64
-}
-
-// ReadToBlocks implements Reader.ReadToBlocks.
-func (r FromIOReaderAt) ReadToBlocks(dsts BlockSeq) (uint64, error) {
- var buf []byte
- var done uint64
- for !dsts.IsEmpty() {
- dst := dsts.Head()
- var n int
- var err error
- n, buf, err = r.readToBlock(dst, buf)
- done += uint64(n)
- if n != dst.Len() {
- return done, err
- }
- dsts = dsts.Tail()
- if err != nil {
- if dsts.IsEmpty() && err == io.EOF {
- return done, nil
- }
- return done, err
- }
- }
- return done, nil
-}
-
-func (r FromIOReaderAt) readToBlock(dst Block, buf []byte) (int, []byte, error) {
- // io.Reader isn't safecopy-aware, so we have to buffer Blocks that require
- // safecopy.
- if !dst.NeedSafecopy() {
- n, err := r.ReaderAt.ReadAt(dst.ToSlice(), r.Offset)
- r.Offset += int64(n)
- return n, buf, err
- }
- if len(buf) < dst.Len() {
- buf = make([]byte, dst.Len())
- }
- rn, rerr := r.ReaderAt.ReadAt(buf[:dst.Len()], r.Offset)
- r.Offset += int64(rn)
- wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn]))
- if wberr != nil {
- return wbn, buf, wberr
- }
- return wbn, buf, rerr
-}
-
// FromIOWriter implements Writer for an io.Writer by repeatedly invoking
// io.Writer.Write until it returns an error or partial write.
//
diff --git a/pkg/sentry/arch/fpu/BUILD b/pkg/sentry/arch/fpu/BUILD
index 6cdd21b1b..1f371e513 100644
--- a/pkg/sentry/arch/fpu/BUILD
+++ b/pkg/sentry/arch/fpu/BUILD
@@ -9,6 +9,7 @@ go_library(
"fpu_amd64.go",
"fpu_amd64.s",
"fpu_arm64.go",
+ "fpu_unsafe.go",
],
visibility = ["//:sandbox"],
deps = [
diff --git a/pkg/sentry/arch/fpu/fpu.go b/pkg/sentry/arch/fpu/fpu.go
index 867d309a3..62bde19d3 100644
--- a/pkg/sentry/arch/fpu/fpu.go
+++ b/pkg/sentry/arch/fpu/fpu.go
@@ -17,7 +17,6 @@ package fpu
import (
"fmt"
- "reflect"
)
// State represents floating point state.
@@ -40,15 +39,3 @@ type ErrLoadingState struct {
func (e ErrLoadingState) Error() string {
return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supportedFeatures, e.savedFeatures)
}
-
-// alignedBytes returns a slice of size bytes, aligned in memory to the given
-// alignment. This is used because we require certain structures to be aligned
-// in a specific way (for example, the X86 floating point data).
-func alignedBytes(size, alignment uint) []byte {
- data := make([]byte, size+alignment-1)
- offset := uint(reflect.ValueOf(data).Index(0).Addr().Pointer() % uintptr(alignment))
- if offset == 0 {
- return data[:size:size]
- }
- return data[alignment-offset:][:size:size]
-}
diff --git a/pkg/sentry/arch/fpu/fpu_unsafe.go b/pkg/sentry/arch/fpu/fpu_unsafe.go
new file mode 100644
index 000000000..c91dc99be
--- /dev/null
+++ b/pkg/sentry/arch/fpu/fpu_unsafe.go
@@ -0,0 +1,31 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fpu
+
+import (
+ "unsafe"
+)
+
+// alignedBytes returns a slice of size bytes, aligned in memory to the given
+// alignment. This is used because we require certain structures to be aligned
+// in a specific way (for example, the X86 floating point data).
+func alignedBytes(size, alignment uint) []byte {
+ data := make([]byte, size+alignment-1)
+ offset := uint(uintptr(unsafe.Pointer(&data[0])) % uintptr(alignment))
+ if offset == 0 {
+ return data[:size:size]
+ }
+ return data[alignment-offset:][:size:size]
+}
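
A typical use is carving out an XSAVE area, which the x86 XSAVE family requires to be 64-byte aligned. A short usage sketch (the 4096-byte size is illustrative):

    buf := alignedBytes(4096, 64)
    if uintptr(unsafe.Pointer(&buf[0]))%64 != 0 {
        panic("unreachable: alignedBytes guarantees alignment")
    }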
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index 7ee237c9f..cfb33a398 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -1,17 +1,25 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "proto_library")
package(licenses = ["notice"])
+proto_library(
+ name = "control",
+ srcs = ["control.proto"],
+ visibility = ["//visibility:public"],
+)
+
go_library(
name = "control",
srcs = [
"control.go",
+ "events.go",
"fs.go",
"lifecycle.go",
"logging.go",
"pprof.go",
"proc.go",
"state.go",
+ "usage.go",
],
visibility = [
"//:sandbox",
@@ -19,6 +27,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/eventchannel",
"//pkg/fd",
"//pkg/log",
"//pkg/sentry/fdimport",
@@ -39,6 +48,7 @@ go_library(
"//pkg/tcpip/link/sniffer",
"//pkg/urpc",
"//pkg/usermem",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/control/control.proto b/pkg/sentry/control/control.proto
new file mode 100644
index 000000000..72dda3fbc
--- /dev/null
+++ b/pkg/sentry/control/control.proto
@@ -0,0 +1,40 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package gvisor;
+
+// ControlConfig configures the permission of controls.
+message ControlConfig {
+ // Names for individual control URPC service objects.
+ // Any new service object that should be given conditional access should be
+ // named here and conditionally added based on presence in allowed_controls.
+ enum Endpoint {
+ UNKNOWN = 0;
+ EVENTS = 1;
+ FS = 2;
+ LIFECYCLE = 3;
+ LOGGING = 4;
+ PROFILE = 5;
+ USAGE = 6;
+ PROC = 7;
+ STATE = 8;
+ DEBUG = 9;
+ }
+
+ // allowed_controls represents which endpoints may be registered to the
+ // server.
+ repeated Endpoint allowed_controls = 1;
+}
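
A server consuming this config would gate registration of each URPC object on membership in allowed_controls. A sketch using the generated Go names protoc conventionally produces for this proto (the srv.Register calls and the cfg variable are assumptions, not a documented API):

    allowed := make(map[controlpb.ControlConfig_Endpoint]bool)
    for _, ep := range cfg.GetAllowedControls() {
        allowed[ep] = true
    }
    if allowed[controlpb.ControlConfig_EVENTS] {
        srv.Register(&control.Events{})
    }
    if allowed[controlpb.ControlConfig_USAGE] {
        srv.Register(&control.Usage{Kernel: k})
    }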
diff --git a/pkg/sentry/control/events.go b/pkg/sentry/control/events.go
new file mode 100644
index 000000000..92e437ae7
--- /dev/null
+++ b/pkg/sentry/control/events.go
@@ -0,0 +1,65 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "errors"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/eventchannel"
+ "gvisor.dev/gvisor/pkg/urpc"
+)
+
+// EventsOpts are the arguments for eventchannel-related commands.
+type EventsOpts struct {
+ urpc.FilePayload
+}
+
+// Events is the control server state for eventchannel-related commands.
+type Events struct {
+ emitter eventchannel.Emitter
+}
+
+// AttachDebugEmitter receives a connected unix domain socket FD from the client
+// and establishes it as a new emitter for the sentry eventchannel. Any existing
+// emitters are replaced on a subsequent attach.
+func (e *Events) AttachDebugEmitter(o *EventsOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errors.New("no output writer provided")
+ }
+
+ sock, err := o.ReleaseFD(0)
+ if err != nil {
+ return err
+ }
+ sockFD := sock.Release()
+
+ // SocketEmitter takes ownership of sockFD.
+ emitter, err := eventchannel.SocketEmitter(sockFD)
+ if err != nil {
+ return fmt.Errorf("failed to create SocketEmitter for FD %d: %v", sockFD, err)
+ }
+
+ // If there is already a debug emitter, close the old one.
+ if e.emitter != nil {
+ e.emitter.Close()
+ }
+
+ e.emitter = eventchannel.DebugEmitterFrom(emitter)
+
+ // Register the new stream destination.
+ eventchannel.AddEmitter(e.emitter)
+ return nil
+}
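
From the client side, attaching an emitter means handing the sentry one end of a connected socket pair over URPC. A sketch, assuming a connected urpc.Client; the socket type and method-name convention here are assumptions:

    func attachDebugEmitter(client *urpc.Client) (readFD int, err error) {
        fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_SEQPACKET, 0)
        if err != nil {
            return -1, err
        }
        sink := os.NewFile(uintptr(fds[1]), "event sink")
        defer sink.Close() // the FD is duplicated on send; our copy can go
        if err := client.Call("Events.AttachDebugEmitter", &control.EventsOpts{
            FilePayload: urpc.FilePayload{Files: []*os.File{sink}},
        }, nil); err != nil {
            return -1, err
        }
        return fds[0], nil // protobuf-framed events arrive here
    }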
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index 2f3664c57..f721b7236 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -26,6 +26,23 @@ import (
"gvisor.dev/gvisor/pkg/urpc"
)
+const (
+ // DefaultBlockProfileRate is the default profiling rate for block
+ // profiles.
+ //
+ // The default here is 10%, which will record a stacktrace 10% of the
+ // time when blocking occurs. Since these events should not be super
+ // frequent, we expect this to achieve a reasonable balance between
+ // collecting the data we need and imposing a high performance cost
+ // (e.g. skewing even the CPU profile).
+ DefaultBlockProfileRate = 10
+
+ // DefaultMutexProfileRate is the default profiling rate for mutex
+ // profiles. Like the block rate above, we use a default rate of 10%
+ // for the same reasons.
+ DefaultMutexProfileRate = 10
+)
+
// Profile includes profile-related RPC stubs. It provides a way to
// control the built-in runtime profiling facilities.
//
@@ -175,12 +192,8 @@ func (p *Profile) Block(o *BlockProfileOpts, _ *struct{}) error {
defer p.blockMu.Unlock()
// Always set the rate. We then wait to collect a profile at this rate,
- // and disable when we're done. Note that the default here is 10%, which
- // will record a stacktrace 10% of the time when blocking occurs. Since
- // these events should not be super frequent, we expect this to achieve
- // a reasonable balance between collecting the data we need and imposing
- // a high performance cost (e.g. skewing even the CPU profile).
- rate := 10
+ // and disable when we're done.
+ rate := DefaultBlockProfileRate
if o.Rate != 0 {
rate = o.Rate
}
@@ -220,9 +233,8 @@ func (p *Profile) Mutex(o *MutexProfileOpts, _ *struct{}) error {
p.mutexMu.Lock()
defer p.mutexMu.Unlock()
- // Always set the fraction. Like the block rate above, we use
- // a default rate of 10% for the same reasons.
- fraction := 10
+ // Always set the fraction.
+ fraction := DefaultMutexProfileRate
if o.Fraction != 0 {
fraction = o.Fraction
}
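
For reference, these knobs ultimately drive the Go runtime's samplers: the mutex fraction means roughly 1 in rate contention events is recorded, while the block rate is handed to runtime.SetBlockProfileRate as its sampling parameter. A sketch of the underlying calls:

    runtime.SetMutexProfileFraction(DefaultMutexProfileRate) // sample ~1/10 contention events
    runtime.SetBlockProfileRate(DefaultBlockProfileRate)     // enable block profiling
    // ... collect profiles via pprof ...
    runtime.SetMutexProfileFraction(0) // disable again when done
    runtime.SetBlockProfileRate(0)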
diff --git a/pkg/sentry/control/usage.go b/pkg/sentry/control/usage.go
new file mode 100644
index 000000000..cc78d3f45
--- /dev/null
+++ b/pkg/sentry/control/usage.go
@@ -0,0 +1,183 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "fmt"
+ "os"
+ "runtime"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/urpc"
+)
+
+// Usage includes usage-related RPC stubs.
+type Usage struct {
+ Kernel *kernel.Kernel
+}
+
+// MemoryUsageOpts contains usage options.
+type MemoryUsageOpts struct {
+ // Full indicates that a full accounting should be done. If Full is not
+ // specified, then a partial accounting will be done, and Unknown will
+ // contain a majority of memory. See Collect for more information.
+ Full bool `json:"Full"`
+}
+
+// MemoryUsage is a memory usage structure.
+type MemoryUsage struct {
+ Unknown uint64 `json:"Unknown"`
+ System uint64 `json:"System"`
+ Anonymous uint64 `json:"Anonymous"`
+ PageCache uint64 `json:"PageCache"`
+ Mapped uint64 `json:"Mapped"`
+ Tmpfs uint64 `json:"Tmpfs"`
+ Ramdiskfs uint64 `json:"Ramdiskfs"`
+ Total uint64 `json:"Total"`
+}
+
+// MemoryUsageFileOpts contains usage file options.
+type MemoryUsageFileOpts struct {
+ // Version is used to ensure both sides agree on the format of the
+ // shared memory buffer.
+ Version uint64 `json:"Version"`
+}
+
+// MemoryUsageFile contains the file handle to the usage file.
+type MemoryUsageFile struct {
+ urpc.FilePayload
+}
+
+// UsageFD returns the file that tracks the memory usage of the application.
+func (u *Usage) UsageFD(opts *MemoryUsageFileOpts, out *MemoryUsageFile) error {
+ // Only support version 1 for now.
+ if opts.Version != 1 {
+ return fmt.Errorf("unsupported version requested: %d", opts.Version)
+ }
+
+ mf := u.Kernel.MemoryFile()
+ *out = MemoryUsageFile{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{
+ usage.MemoryAccounting.File,
+ mf.File(),
+ },
+ },
+ }
+
+ return nil
+}
+
+// Collect returns memory used by the sandboxed application.
+func (u *Usage) Collect(opts *MemoryUsageOpts, out *MemoryUsage) error {
+ if opts.Full {
+ // Ensure everything is up to date.
+ if err := u.Kernel.MemoryFile().UpdateUsage(); err != nil {
+ return err
+ }
+
+ // Copy out a snapshot.
+ snapshot, total := usage.MemoryAccounting.Copy()
+ *out = MemoryUsage{
+ System: snapshot.System,
+ Anonymous: snapshot.Anonymous,
+ PageCache: snapshot.PageCache,
+ Mapped: snapshot.Mapped,
+ Tmpfs: snapshot.Tmpfs,
+ Ramdiskfs: snapshot.Ramdiskfs,
+ Total: total,
+ }
+ } else {
+ // Get total usage from the MemoryFile implementation.
+ total, err := u.Kernel.MemoryFile().TotalUsage()
+ if err != nil {
+ return err
+ }
+
+ // The memory accounting is guaranteed to be accurate only when
+ // UpdateUsage is called. If UpdateUsage is not called, then only Mapped
+ // will be up-to-date.
+ snapshot, _ := usage.MemoryAccounting.Copy()
+ *out = MemoryUsage{
+ Unknown: total,
+ Mapped: snapshot.Mapped,
+ Total: total + snapshot.Mapped,
+ }
+
+ }
+
+ return nil
+}
+
+// UsageReduceOpts contains options to Usage.Reduce().
+type UsageReduceOpts struct {
+ // If Wait is true, Reduce blocks until all activity initiated by
+ // Usage.Reduce() has completed.
+ Wait bool `json:"wait"`
+}
+
+// UsageReduceOutput contains output from Usage.Reduce().
+type UsageReduceOutput struct{}
+
+// Reduce requests that the sentry attempt to reduce its memory usage.
+func (u *Usage) Reduce(opts *UsageReduceOpts, out *UsageReduceOutput) error {
+ mf := u.Kernel.MemoryFile()
+ mf.StartEvictions()
+ if opts.Wait {
+ mf.WaitForEvictions()
+ }
+ return nil
+}
+
+// MemoryUsageRecord contains the mapping and platform memory file.
+type MemoryUsageRecord struct {
+ mmap uintptr
+ stats *usage.RTMemoryStats
+ mf os.File
+}
+
+// NewMemoryUsageRecord creates a new MemoryUsageRecord from usageFile and
+// platformFile.
+func NewMemoryUsageRecord(usageFile, platformFile os.File) (*MemoryUsageRecord, error) {
+ mmap, _, e := unix.RawSyscall6(unix.SYS_MMAP, 0, usage.RTMemoryStatsSize, unix.PROT_READ, unix.MAP_SHARED, usageFile.Fd(), 0)
+ if e != 0 {
+ return nil, fmt.Errorf("mmap returned %d, want 0", e)
+ }
+
+ m := MemoryUsageRecord{
+ mmap: mmap,
+ stats: usage.RTMemoryStatsPointer(mmap),
+ mf: platformFile,
+ }
+
+ runtime.SetFinalizer(&m, finalizer)
+ return &m, nil
+}
+
+func finalizer(m *MemoryUsageRecord) {
+ unix.RawSyscall(unix.SYS_MUNMAP, m.mmap, usage.RTMemoryStatsSize, 0)
+}
+
+// Fetch fetches the usage info from a MemoryUsageRecord.
+func (m *MemoryUsageRecord) Fetch() (mapped, unknown, total uint64, err error) {
+ var stat unix.Stat_t
+ if err := unix.Fstat(int(m.mf.Fd()), &stat); err != nil {
+ return 0, 0, 0, err
+ }
+ fmem := uint64(stat.Blocks) * 512
+ return m.stats.RTMapped, fmem, m.stats.RTMapped + fmem, nil
+}
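
Putting the pieces together: a client fetches the two files once via Usage.UsageFD, wraps them in a MemoryUsageRecord, and can then poll Fetch without further RPCs. A sketch, assuming a connected urpc.Client:

    var out control.MemoryUsageFile
    if err := client.Call("Usage.UsageFD", &control.MemoryUsageFileOpts{Version: 1}, &out); err != nil {
        return err
    }
    rec, err := control.NewMemoryUsageRecord(*out.Files[0], *out.Files[1])
    if err != nil {
        return err
    }
    mapped, unknown, total, err := rec.Fetch() // cheap: shared memory read + fstat
    if err != nil {
        return err
    }
    fmt.Printf("mapped=%d unknown=%d total=%d\n", mapped, unknown, total)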
diff --git a/pkg/sentry/devices/memdev/BUILD b/pkg/sentry/devices/memdev/BUILD
index 4c8604d58..66b9ed523 100644
--- a/pkg/sentry/devices/memdev/BUILD
+++ b/pkg/sentry/devices/memdev/BUILD
@@ -15,6 +15,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/rand",
"//pkg/safemem",
"//pkg/sentry/fsimpl/devtmpfs",
@@ -23,7 +24,6 @@ go_library(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
"//pkg/sentry/vfs",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go
index fece3e762..fc702c9f6 100644
--- a/pkg/sentry/devices/memdev/full.go
+++ b/pkg/sentry/devices/memdev/full.go
@@ -16,8 +16,8 @@ package memdev
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -66,12 +66,12 @@ func (fd *fullFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.Rea
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *fullFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
- return 0, syserror.ENOSPC
+ return 0, linuxerr.ENOSPC
}
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *fullFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
- return 0, syserror.ENOSPC
+ return 0, linuxerr.ENOSPC
}
// Seek implements vfs.FileDescriptionImpl.Seek.
diff --git a/pkg/sentry/devices/quotedev/BUILD b/pkg/sentry/devices/quotedev/BUILD
deleted file mode 100644
index d09214e3e..000000000
--- a/pkg/sentry/devices/quotedev/BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-licenses(["notice"])
-
-go_library(
- name = "quotedev",
- srcs = ["quotedev.go"],
- visibility = ["//pkg/sentry:internal"],
- deps = [
- "//pkg/abi/linux",
- "//pkg/context",
- "//pkg/sentry/fsimpl/devtmpfs",
- "//pkg/sentry/vfs",
- "//pkg/syserror",
- ],
-)
diff --git a/pkg/sentry/devices/quotedev/quotedev.go b/pkg/sentry/devices/quotedev/quotedev.go
deleted file mode 100644
index 6114cb724..000000000
--- a/pkg/sentry/devices/quotedev/quotedev.go
+++ /dev/null
@@ -1,52 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package quotedev implements a vfs.Device for /dev/gvisor_quote.
-package quotedev
-
-import (
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-const (
- quoteDevMinor = 0
-)
-
-// quoteDevice implements vfs.Device for /dev/gvisor_quote
-//
-// +stateify savable
-type quoteDevice struct{}
-
-// Open implements vfs.Device.Open.
-// TODO(b/157161182): Add support for attestation ioctls.
-func (quoteDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- return nil, syserror.EIO
-}
-
-// Register registers all devices implemented by this package in vfsObj.
-func Register(vfsObj *vfs.VirtualFilesystem) error {
- return vfsObj.RegisterDevice(vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, quoteDevice{}, &vfs.RegisterDeviceOptions{
- GroupName: "gvisor_quote",
- })
-}
-
-// CreateDevtmpfsFiles creates device special files in dev representing all
-// devices implemented by this package.
-func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error {
- return dev.CreateDeviceFile(ctx, "gvisor_quote", vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, 0666 /* mode */)
-}
diff --git a/pkg/sentry/devices/ttydev/BUILD b/pkg/sentry/devices/ttydev/BUILD
index b4b6ca38a..ab4cd0b33 100644
--- a/pkg/sentry/devices/ttydev/BUILD
+++ b/pkg/sentry/devices/ttydev/BUILD
@@ -9,8 +9,8 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/sentry/fsimpl/devtmpfs",
"//pkg/sentry/vfs",
- "//pkg/syserror",
],
)
diff --git a/pkg/sentry/devices/ttydev/ttydev.go b/pkg/sentry/devices/ttydev/ttydev.go
index a287c65ca..29b79b5d6 100644
--- a/pkg/sentry/devices/ttydev/ttydev.go
+++ b/pkg/sentry/devices/ttydev/ttydev.go
@@ -18,9 +18,9 @@ package ttydev
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
const (
@@ -36,7 +36,7 @@ type ttyDevice struct{}
// Open implements vfs.Device.Open.
func (ttyDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- return nil, syserror.EIO
+ return nil, linuxerr.EIO
}
// Register registers all devices implemented by this package in vfsObj.
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index 58fe1e77c..4e573d249 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -68,7 +68,6 @@ go_library(
"//pkg/sentry/usage",
"//pkg/state",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index a8591052c..e48bd4dba 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -195,7 +194,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
attrs, err := next.Inode.overlay.lower.UnstableAttr(ctx)
if err != nil {
log.Warningf("copy up failed to get lower attributes: %v", err)
- return syserror.EIO
+ return linuxerr.EIO
}
var childUpperInode *Inode
@@ -211,7 +210,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
childFile, err := parentUpper.Create(ctx, root, next.name, FileFlags{Read: true, Write: true}, attrs.Perms)
if err != nil {
log.Warningf("copy up failed to create file: %v", err)
- return syserror.EIO
+ return linuxerr.EIO
}
defer childFile.DecRef(ctx)
childUpperInode = childFile.Dirent.Inode
@@ -219,13 +218,13 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
case Directory:
if err := parentUpper.CreateDirectory(ctx, root, next.name, attrs.Perms); err != nil {
log.Warningf("copy up failed to create directory: %v", err)
- return syserror.EIO
+ return linuxerr.EIO
}
childUpper, err := parentUpper.Lookup(ctx, next.name)
if err != nil {
werr := fmt.Errorf("copy up failed to lookup directory: %v", err)
cleanupUpper(ctx, parentUpper, next.name, werr)
- return syserror.EIO
+ return linuxerr.EIO
}
defer childUpper.DecRef(ctx)
childUpperInode = childUpper.Inode
@@ -235,17 +234,17 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
link, err := childLower.Readlink(ctx)
if err != nil {
log.Warningf("copy up failed to read symlink value: %v", err)
- return syserror.EIO
+ return linuxerr.EIO
}
if err := parentUpper.CreateLink(ctx, root, link, next.name); err != nil {
log.Warningf("copy up failed to create symlink: %v", err)
- return syserror.EIO
+ return linuxerr.EIO
}
childUpper, err := parentUpper.Lookup(ctx, next.name)
if err != nil {
werr := fmt.Errorf("copy up failed to lookup symlink: %v", err)
cleanupUpper(ctx, parentUpper, next.name, werr)
- return syserror.EIO
+ return linuxerr.EIO
}
defer childUpper.DecRef(ctx)
childUpperInode = childUpper.Inode
@@ -259,14 +258,14 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
if err := copyAttributesLocked(ctx, childUpperInode, next.Inode.overlay.lower); err != nil {
werr := fmt.Errorf("copy up failed to copy up attributes: %v", err)
cleanupUpper(ctx, parentUpper, next.name, werr)
- return syserror.EIO
+ return linuxerr.EIO
}
// Copy the entire file.
if err := copyContentsLocked(ctx, childUpperInode, next.Inode.overlay.lower, attrs.Size); err != nil {
werr := fmt.Errorf("copy up failed to copy up contents: %v", err)
cleanupUpper(ctx, parentUpper, next.name, werr)
- return syserror.EIO
+ return linuxerr.EIO
}
lowerMappable := next.Inode.overlay.lower.Mappable()
@@ -274,7 +273,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
if lowerMappable != nil && upperMappable == nil {
werr := fmt.Errorf("copy up failed: cannot ensure memory mapping coherence")
cleanupUpper(ctx, parentUpper, next.name, werr)
- return syserror.EIO
+ return linuxerr.EIO
}
// Propagate memory mappings to the upper Inode.
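
Every failure branch in copyUpLocked above logs the precise cause and then returns the same generic EIO, so backend internals never leak to the caller. A minimal standalone sketch of that pattern, with errIO and copyAttrs as stand-ins for linuxerr.EIO and the real copy-up steps:

package main

import (
	"errors"
	"fmt"
	"log"
)

// errIO stands in for linuxerr.EIO; in gVisor it is a structured error
// carrying the Linux errno.
var errIO = errors.New("input/output error")

// copyAttrs is a hypothetical step that can fail with a backend-specific
// error we do not want to expose to the caller.
func copyAttrs() error {
	return fmt.Errorf("9p: connection reset")
}

func copyUp() error {
	if err := copyAttrs(); err != nil {
		// Log the precise cause for debugging, but return only the
		// generic error so filesystem internals do not propagate.
		log.Printf("copy up failed to copy attributes: %v", err)
		return errIO
	}
	return nil
}

func main() {
	fmt.Println(copyUp()) // input/output error
}
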
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index e28a8961b..7baf26b24 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -34,7 +34,6 @@ go_library(
"//pkg/sentry/mm",
"//pkg/sentry/pgalloc",
"//pkg/sentry/socket/netstack",
- "//pkg/syserror",
"//pkg/tcpip/link/tun",
"//pkg/usermem",
"//pkg/waiter",
diff --git a/pkg/sentry/fs/dev/full.go b/pkg/sentry/fs/dev/full.go
index deb9c6ad8..6f0c1fc68 100644
--- a/pkg/sentry/fs/dev/full.go
+++ b/pkg/sentry/fs/dev/full.go
@@ -17,9 +17,9 @@ package dev
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -77,5 +77,5 @@ var _ fs.FileOperations = (*fullFileOperations)(nil)
// Write implements FileOperations.Write.
func (*fullFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, syserror.ENOSPC
+ return 0, linuxerr.ENOSPC
}
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index ad8ff227e..d300a32e0 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -28,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
type globalDirentMap struct {
@@ -963,7 +962,7 @@ func (d *Dirent) isMountPointLocked() bool {
func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err error) {
// Did we race with deletion?
if atomic.LoadInt32(&d.deleted) != 0 {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
// Refuse to mount a symlink.
@@ -998,7 +997,7 @@ func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err
func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error {
// Did we race with deletion?
if atomic.LoadInt32(&d.deleted) != 0 {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
// Remount our former child in its place.
diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD
index 5c889c861..9f1fe5160 100644
--- a/pkg/sentry/fs/fdpipe/BUILD
+++ b/pkg/sentry/fs/fdpipe/BUILD
@@ -22,7 +22,6 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
@@ -46,7 +45,6 @@ go_test(
"//pkg/hostarch",
"//pkg/sentry/contexttest",
"//pkg/sentry/fs",
- "//pkg/syserror",
"//pkg/usermem",
"@com_github_google_uuid//:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go
index f8a29816b..4370cce33 100644
--- a/pkg/sentry/fs/fdpipe/pipe.go
+++ b/pkg/sentry/fs/fdpipe/pipe.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -142,7 +141,7 @@ func (p *pipeOperations) Read(ctx context.Context, file *fs.File, dst usermem.IO
n, err := dst.CopyOutFrom(ctx, safemem.FromIOReader{secio.FullReader{p.file}})
total := int64(bufN) + n
if err != nil && isBlockError(err) {
- return total, syserror.ErrWouldBlock
+ return total, linuxerr.ErrWouldBlock
}
return total, err
}
@@ -151,13 +150,13 @@ func (p *pipeOperations) Read(ctx context.Context, file *fs.File, dst usermem.IO
func (p *pipeOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
n, err := src.CopyInTo(ctx, safemem.FromIOWriter{p.file})
if err != nil && isBlockError(err) {
- return n, syserror.ErrWouldBlock
+ return n, linuxerr.ErrWouldBlock
}
return n, err
}
// isBlockError unwraps os errors and checks if they are caused by EAGAIN or
-// EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock.
+// EWOULDBLOCK. This is so they can be transformed into linuxerr.ErrWouldBlock.
func isBlockError(err error) bool {
if linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err) {
return true
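
isBlockError is the glue between host errno values and the sentry's would-block sentinel. A standalone sketch of the same conversion, using the standard library's errors.Is in place of linuxerr.Equals and a local errWouldBlock stand-in for linuxerr.ErrWouldBlock:

package main

import (
	"errors"
	"fmt"
	"os"
	"syscall"
)

// errWouldBlock stands in for linuxerr.ErrWouldBlock, the sentinel the
// sentry's syscall layer later translates back into EAGAIN.
var errWouldBlock = errors.New("would block")

// isBlockError mirrors the helper above: unwrap wrapped OS errors and
// check for EAGAIN/EWOULDBLOCK.
func isBlockError(err error) bool {
	return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EWOULDBLOCK)
}

func main() {
	// A wrapped EAGAIN, as os.File.Read would return it from a
	// non-blocking pipe with no data available.
	err := &os.PathError{Op: "read", Path: "pipe", Err: syscall.EAGAIN}
	if isBlockError(err) {
		fmt.Println(errWouldBlock) // would block
	}
}
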
diff --git a/pkg/sentry/fs/fdpipe/pipe_opener.go b/pkg/sentry/fs/fdpipe/pipe_opener.go
index adda19168..e91e1b5cb 100644
--- a/pkg/sentry/fs/fdpipe/pipe_opener.go
+++ b/pkg/sentry/fs/fdpipe/pipe_opener.go
@@ -21,9 +21,9 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// NonBlockingOpener is a generic host file opener used to retry opening host
@@ -40,7 +40,7 @@ func Open(ctx context.Context, opener NonBlockingOpener, flags fs.FileFlags) (fs
p := &pipeOpenState{}
canceled := false
for {
- if file, err := p.TryOpen(ctx, opener, flags); err != syserror.ErrWouldBlock {
+ if file, err := p.TryOpen(ctx, opener, flags); err != linuxerr.ErrWouldBlock {
return file, err
}
@@ -51,7 +51,7 @@ func Open(ctx context.Context, opener NonBlockingOpener, flags fs.FileFlags) (fs
if p.hostFile != nil {
p.hostFile.Close()
}
- return nil, syserror.ErrInterrupted
+ return nil, linuxerr.ErrInterrupted
}
cancel := ctx.SleepStart()
@@ -106,13 +106,13 @@ func (p *pipeOpenState) TryOpen(ctx context.Context, opener NonBlockingOpener, f
}
return newPipeOperations(ctx, opener, flags, f, nil)
- // Handle opening O_WRONLY blocking: convert ENXIO to syserror.ErrWouldBlock.
+ // Handle opening O_WRONLY blocking: convert ENXIO to linuxerr.ErrWouldBlock.
// See TryOpenWriteOnly for more details.
case flags.Write:
return p.TryOpenWriteOnly(ctx, opener)
default:
- // Handle opening O_RDONLY blocking: convert EOF from read to syserror.ErrWouldBlock.
+ // Handle opening O_RDONLY blocking: convert EOF from read to linuxerr.ErrWouldBlock.
// See TryOpenReadOnly for more details.
return p.TryOpenReadOnly(ctx, opener)
}
@@ -120,7 +120,7 @@ func (p *pipeOpenState) TryOpen(ctx context.Context, opener NonBlockingOpener, f
// TryOpenReadOnly tries to open a host pipe read only but only returns a fs.File when
// there is a coordinating writer. Call TryOpenReadOnly repeatedly on the same pipeOpenState
-// until syserror.ErrWouldBlock is no longer returned.
+// until linuxerr.ErrWouldBlock is no longer returned.
//
// How it works:
//
@@ -150,7 +150,7 @@ func (p *pipeOpenState) TryOpenReadOnly(ctx context.Context, opener NonBlockingO
if n == 0 {
// EOF means that we're not ready yet.
if rerr == nil || rerr == io.EOF {
- return nil, syserror.ErrWouldBlock
+ return nil, linuxerr.ErrWouldBlock
}
// Any error that is not EWOULDBLOCK also means we're not
// ready yet, and probably never will be ready. In this
@@ -175,16 +175,16 @@ func (p *pipeOpenState) TryOpenReadOnly(ctx context.Context, opener NonBlockingO
// TryOpenWriteOnly tries to open a host pipe write only but only returns a fs.File when
// there is a coordinating reader. Call TryOpenWriteOnly repeatedly on the same pipeOpenState
-// until syserror.ErrWouldBlock is no longer returned.
+// until linuxerr.ErrWouldBlock is no longer returned.
//
// How it works:
//
// Opening a pipe write only will return ENXIO until readers are available. Converts the ENXIO
-// to an syserror.ErrWouldBlock, to tell callers to retry.
+ // to a linuxerr.ErrWouldBlock, to tell callers to retry.
func (*pipeOpenState) TryOpenWriteOnly(ctx context.Context, opener NonBlockingOpener) (*pipeOperations, error) {
hostFile, err := opener.NonBlockingOpen(ctx, fs.PermMask{Write: true})
if unwrapError(err) == unix.ENXIO {
- return nil, syserror.ErrWouldBlock
+ return nil, linuxerr.ErrWouldBlock
}
if err != nil {
return nil, err
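
Open drives TryOpen in a loop until the would-block sentinel stops coming back, bailing out with ErrInterrupted between attempts. A simplified, runnable sketch of that retry shape, with tryOpen, errWouldBlock, and errInterrupted standing in for the gVisor names:

package main

import (
	"errors"
	"fmt"
	"time"
)

var (
	errWouldBlock  = errors.New("would block") // stand-in for linuxerr.ErrWouldBlock
	errInterrupted = errors.New("interrupted") // stand-in for linuxerr.ErrInterrupted
)

// tryOpen is a hypothetical non-blocking attempt; it succeeds on the third
// call, the way a writer-less FIFO starts succeeding once a peer appears.
func tryOpen(attempt int) (string, error) {
	if attempt < 3 {
		return "", errWouldBlock
	}
	return "file", nil
}

// open mirrors fdpipe.Open's shape: loop until tryOpen returns anything
// other than the would-block sentinel, checking for interruption between
// attempts.
func open(interrupted func() bool) (string, error) {
	for attempt := 0; ; attempt++ {
		if f, err := tryOpen(attempt); err != errWouldBlock {
			return f, err
		}
		if interrupted() {
			return "", errInterrupted
		}
		time.Sleep(time.Millisecond) // ctx.SleepStart/SleepFinish in gVisor
	}
}

func main() {
	f, err := open(func() bool { return false })
	fmt.Println(f, err) // file <nil>
}
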
diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go
index 89d8be741..e1587288e 100644
--- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go
+++ b/pkg/sentry/fs/fdpipe/pipe_opener_test.go
@@ -30,7 +30,6 @@ import (
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -146,18 +145,18 @@ func TestTryOpen(t *testing.T) {
err: unix.ENOENT,
},
{
- desc: "Blocking Write only returns with syserror.ErrWouldBlock",
+ desc: "Blocking Write only returns with linuxerr.ErrWouldBlock",
makePipe: true,
flags: fs.FileFlags{Write: true},
expectFile: false,
- err: syserror.ErrWouldBlock,
+ err: linuxerr.ErrWouldBlock,
},
{
- desc: "Blocking Read only returns with syserror.ErrWouldBlock",
+ desc: "Blocking Read only returns with linuxerr.ErrWouldBlock",
makePipe: true,
flags: fs.FileFlags{Read: true},
expectFile: false,
- err: syserror.ErrWouldBlock,
+ err: linuxerr.ErrWouldBlock,
},
} {
name := pipename()
@@ -316,7 +315,7 @@ func TestCopiedReadAheadBuffer(t *testing.T) {
// another writer comes along. This means we can open the same pipe write only
// with no problems + write to it, given that opener.Open already tried to open
// the pipe RDONLY and succeeded, which we know happened if TryOpen returns
- // syserror.ErrwouldBlock.
+ // linuxerr.ErrWouldBlock.
//
// This simulates the open(RDONLY) <-> open(WRONLY)+write race we care about, but
// does not cause our test to be racy (which would be terrible).
@@ -328,8 +327,8 @@ func TestCopiedReadAheadBuffer(t *testing.T) {
pipeOps.Release(ctx)
t.Fatalf("open(%s, %o) got file, want nil", name, unix.O_RDONLY)
}
- if err != syserror.ErrWouldBlock {
- t.Fatalf("open(%s, %o) got error %v, want %v", name, unix.O_RDONLY, err, syserror.ErrWouldBlock)
+ if err != linuxerr.ErrWouldBlock {
+ t.Fatalf("open(%s, %o) got error %v, want %v", name, unix.O_RDONLY, err, linuxerr.ErrWouldBlock)
}
// Then open the same pipe write only and write some bytes to it. The next
diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go
index 4c8905a7e..63900e766 100644
--- a/pkg/sentry/fs/fdpipe/pipe_test.go
+++ b/pkg/sentry/fs/fdpipe/pipe_test.go
@@ -28,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -238,7 +237,7 @@ func TestPipeRequest(t *testing.T) {
context: &Readv{Dst: usermem.BytesIOSequence(make([]byte, 10))},
flags: fs.FileFlags{Read: true},
keepOpenPartner: true,
- err: syserror.ErrWouldBlock,
+ err: linuxerr.ErrWouldBlock,
},
{
desc: "Writev on pipe from empty buffer returns nil",
@@ -410,8 +409,8 @@ func TestPipeReadsAccumulate(t *testing.T) {
n, err := p.Read(ctx, file, iov, 0)
total := n
iov = iov.DropFirst64(n)
- if err != syserror.ErrWouldBlock {
- t.Fatalf("Readv got error %v, want %v", err, syserror.ErrWouldBlock)
+ if err != linuxerr.ErrWouldBlock {
+ t.Fatalf("Readv got error %v, want %v", err, linuxerr.ErrWouldBlock)
}
// Write a few more bytes to allow us to read more/accumulate.
@@ -479,8 +478,8 @@ func TestPipeWritesAccumulate(t *testing.T) {
}
iov := usermem.BytesIOSequence(writeBuffer)
n, err := p.Write(ctx, file, iov, 0)
- if err != syserror.ErrWouldBlock {
- t.Fatalf("Writev got error %v, want %v", err, syserror.ErrWouldBlock)
+ if err != linuxerr.ErrWouldBlock {
+ t.Fatalf("Writev got error %v, want %v", err, linuxerr.ErrWouldBlock)
}
if n != int64(pipeSize) {
t.Fatalf("Writev partial write, got: %v, want %v", n, pipeSize)
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index 57f904801..df04f044d 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsmetric"
@@ -27,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -195,10 +195,10 @@ func (f *File) EventUnregister(e *waiter.Entry) {
// offset to the value returned by f.FileOperations.Seek if the operation
// is successful.
//
-// Returns syserror.ErrInterrupted if seeking was interrupted.
+// Returns linuxerr.ErrInterrupted if seeking was interrupted.
func (f *File) Seek(ctx context.Context, whence SeekWhence, offset int64) (int64, error) {
if !f.mu.Lock(ctx) {
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
defer f.mu.Unlock()
@@ -217,10 +217,10 @@ func (f *File) Seek(ctx context.Context, whence SeekWhence, offset int64) (int64
// Readdir unconditionally updates the access time on the File's Inode,
// see fs/readdir.c:iterate_dir.
//
-// Returns syserror.ErrInterrupted if reading was interrupted.
+// Returns linuxerr.ErrInterrupted if reading was interrupted.
func (f *File) Readdir(ctx context.Context, serializer DentrySerializer) error {
if !f.mu.Lock(ctx) {
- return syserror.ErrInterrupted
+ return linuxerr.ErrInterrupted
}
defer f.mu.Unlock()
@@ -232,13 +232,13 @@ func (f *File) Readdir(ctx context.Context, serializer DentrySerializer) error {
// Readv calls f.FileOperations.Read with f as the File, advancing the file
// offset if f.FileOperations.Read returns bytes read > 0.
//
-// Returns syserror.ErrInterrupted if reading was interrupted.
+// Returns linuxerr.ErrInterrupted if reading was interrupted.
func (f *File) Readv(ctx context.Context, dst usermem.IOSequence) (int64, error) {
start := fsmetric.StartReadWait()
defer fsmetric.FinishReadWait(fsmetric.ReadWait, start)
if !f.mu.Lock(ctx) {
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
fsmetric.Reads.Increment()
@@ -260,7 +260,7 @@ func (f *File) Preadv(ctx context.Context, dst usermem.IOSequence, offset int64)
defer fsmetric.FinishReadWait(fsmetric.ReadWait, start)
if !f.mu.Lock(ctx) {
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
fsmetric.Reads.Increment()
@@ -276,10 +276,10 @@ func (f *File) Preadv(ctx context.Context, dst usermem.IOSequence, offset int64)
// unavoidably racy for network file systems. Writev also truncates src
// to avoid overrunning the current file size limit if necessary.
//
-// Returns syserror.ErrInterrupted if writing was interrupted.
+// Returns linuxerr.ErrInterrupted if writing was interrupted.
func (f *File) Writev(ctx context.Context, src usermem.IOSequence) (int64, error) {
if !f.mu.Lock(ctx) {
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append)
// Handle append mode.
@@ -297,7 +297,7 @@ func (f *File) Writev(ctx context.Context, src usermem.IOSequence) (int64, error
case ok && limit == 0:
unlockAppendMu()
f.mu.Unlock()
- return 0, syserror.ErrExceedsFileSizeLimit
+ return 0, linuxerr.ErrExceedsFileSizeLimit
case ok:
src = src.TakeFirst64(limit)
}
@@ -335,7 +335,7 @@ func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64
limit, ok := f.checkLimit(ctx, offset)
switch {
case ok && limit == 0:
- return 0, syserror.ErrExceedsFileSizeLimit
+ return 0, linuxerr.ErrExceedsFileSizeLimit
case ok:
src = src.TakeFirst64(limit)
}
@@ -352,7 +352,7 @@ func (f *File) offsetForAppend(ctx context.Context, offset *int64) error {
if err != nil {
// This is an odd error, we treat it as evidence that
// something is terribly wrong with the filesystem.
- return syserror.EIO
+ return linuxerr.EIO
}
// Update the offset.
@@ -381,10 +381,10 @@ func (f *File) checkLimit(ctx context.Context, offset int64) (int64, bool) {
// Fsync calls f.FileOperations.Fsync with f as the File.
//
-// Returns syserror.ErrInterrupted if syncing was interrupted.
+// Returns linuxerr.ErrInterrupted if syncing was interrupted.
func (f *File) Fsync(ctx context.Context, start int64, end int64, syncType SyncType) error {
if !f.mu.Lock(ctx) {
- return syserror.ErrInterrupted
+ return linuxerr.ErrInterrupted
}
defer f.mu.Unlock()
@@ -393,10 +393,10 @@ func (f *File) Fsync(ctx context.Context, start int64, end int64, syncType SyncT
// Flush calls f.FileOperations.Flush with f as the File.
//
-// Returns syserror.ErrInterrupted if syncing was interrupted.
+// Returns linuxerr.ErrInterrupted if syncing was interrupted.
func (f *File) Flush(ctx context.Context) error {
if !f.mu.Lock(ctx) {
- return syserror.ErrInterrupted
+ return linuxerr.ErrInterrupted
}
defer f.mu.Unlock()
@@ -405,10 +405,10 @@ func (f *File) Flush(ctx context.Context) error {
// ConfigureMMap calls f.FileOperations.ConfigureMMap with f as the File.
//
-// Returns syserror.ErrInterrupted if interrupted.
+// Returns linuxerr.ErrInterrupted if interrupted.
func (f *File) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
if !f.mu.Lock(ctx) {
- return syserror.ErrInterrupted
+ return linuxerr.ErrInterrupted
}
defer f.mu.Unlock()
@@ -417,10 +417,10 @@ func (f *File) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
// UnstableAttr calls f.FileOperations.UnstableAttr with f as the File.
//
-// Returns syserror.ErrInterrupted if interrupted.
+// Returns linuxerr.ErrInterrupted if interrupted.
func (f *File) UnstableAttr(ctx context.Context) (UnstableAttr, error) {
if !f.mu.Lock(ctx) {
- return UnstableAttr{}, syserror.ErrInterrupted
+ return UnstableAttr{}, linuxerr.ErrInterrupted
}
defer f.mu.Unlock()
@@ -495,7 +495,7 @@ type lockedReader struct {
// Read implements io.Reader.Read.
func (r *lockedReader) Read(buf []byte) (int, error) {
if r.Ctx.Interrupted() {
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.Offset)
r.Offset += n
@@ -505,7 +505,7 @@ func (r *lockedReader) Read(buf []byte) (int, error) {
// ReadAt implements io.Reader.ReadAt.
func (r *lockedReader) ReadAt(buf []byte, offset int64) (int, error) {
if r.Ctx.Interrupted() {
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), offset)
return int(n), err
@@ -530,7 +530,7 @@ type lockedWriter struct {
// Write implements io.Writer.Write.
func (w *lockedWriter) Write(buf []byte) (int, error) {
if w.Ctx.Interrupted() {
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
n, err := w.WriteAt(buf, w.Offset)
w.Offset += int64(n)
@@ -549,7 +549,7 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) {
// contract. Enforce that here.
for written < len(buf) {
if w.Ctx.Interrupted() {
- return written, syserror.ErrInterrupted
+ return written, linuxerr.ErrInterrupted
}
var n int64
n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written))
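
Every File entry point above begins with the same interruptible-lock prologue. A toy sketch of the idea behind pkg/amutex, with a channel-based mutex and a local errInterrupted standing in for linuxerr.ErrInterrupted:

package main

import (
	"errors"
	"fmt"
)

var errInterrupted = errors.New("interrupted") // stand-in for linuxerr.ErrInterrupted

// imutex is a toy interruptible mutex: Lock returns false instead of
// blocking forever when the caller is interrupted.
type imutex struct{ ch chan struct{} }

func newIMutex() *imutex {
	m := &imutex{ch: make(chan struct{}, 1)}
	m.ch <- struct{}{} // one token: unlocked
	return m
}

func (m *imutex) Lock(interrupt <-chan struct{}) bool {
	select {
	case <-m.ch:
		return true
	case <-interrupt:
		return false
	}
}

func (m *imutex) Unlock() { m.ch <- struct{}{} }

// seek mirrors File.Seek's prologue: bail out with ErrInterrupted when the
// lock cannot be taken.
func seek(m *imutex, interrupt <-chan struct{}) (int64, error) {
	if !m.Lock(interrupt) {
		return 0, errInterrupted
	}
	defer m.Unlock()
	return 42, nil
}

func main() {
	m := newIMutex()
	fmt.Println(seek(m, nil)) // 42 <nil>
}
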
diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go
index 6ec721022..ce47c3907 100644
--- a/pkg/sentry/fs/file_operations.go
+++ b/pkg/sentry/fs/file_operations.go
@@ -120,7 +120,7 @@ type FileOperations interface {
// Files with !FileFlags.Pwrite.
//
// If only part of src could be written, Write must return an error
- // indicating why (e.g. syserror.ErrWouldBlock).
+ // indicating why (e.g. linuxerr.ErrWouldBlock).
//
// Write does not check permissions nor flags.
//
diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go
index 06c07c807..a27dd0b9a 100644
--- a/pkg/sentry/fs/file_overlay.go
+++ b/pkg/sentry/fs/file_overlay.go
@@ -16,6 +16,7 @@ package fs
import (
"io"
+ "math"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
@@ -23,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -246,7 +246,7 @@ func (f *overlayFileOperations) onTop(ctx context.Context, file *File, fn func(*
// Something very wrong; return a generic filesystem
// error to avoid propagating internals.
f.upperMu.Unlock()
- return syserror.EIO
+ return linuxerr.EIO
}
// Save upper file.
@@ -361,10 +361,13 @@ func (*overlayFileOperations) ConfigureMMap(ctx context.Context, file *File, opt
return linuxerr.ENODEV
}
- // FIXME(jamieliu): This is a copy/paste of fsutil.GenericConfigureMMap,
- // which we can't use because the overlay implementation is in package fs,
- // so depending on fs/fsutil would create a circular dependency. Move
- // overlay to fs/overlay.
+ // TODO(gvisor.dev/issue/1624): This is a copy/paste of
+ // fsutil.GenericConfigureMMap, which we can't use because the overlay
+ // implementation is in package fs, so depending on fs/fsutil would create
+ // a circular dependency. VFS2 overlay doesn't have this issue.
+ if opts.Offset+opts.Length > math.MaxInt64 {
+ return linuxerr.EOVERFLOW
+ }
opts.Mappable = o
opts.MappingIdentity = file
file.IncRef()
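
The new guard rejects mappings whose end offset cannot be represented as an int64; the same check lands in fsutil.GenericConfigureMMap below. A minimal sketch of the arithmetic, assuming (as in the hunk) that callers have already bounded Length so the uint64 addition itself does not wrap:

package main

import (
	"errors"
	"fmt"
	"math"
)

var errOverflow = errors.New("EOVERFLOW") // stand-in for linuxerr.EOVERFLOW

// checkMMapRange mirrors the guard added above: memmap.MMapOpts carries
// Offset and Length as uint64, and their sum must remain a valid int64
// file offset.
func checkMMapRange(offset, length uint64) error {
	if offset+length > math.MaxInt64 {
		return errOverflow
	}
	return nil
}

func main() {
	fmt.Println(checkMMapRange(0, 4096))             // <nil>
	fmt.Println(checkMMapRange(math.MaxInt64, 4096)) // EOVERFLOW
}
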
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index 6bf2d51cb..1a59800ea 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -90,7 +90,6 @@ go_library(
"//pkg/sentry/usage",
"//pkg/state",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go
index 00b3bb29b..38e3ed42d 100644
--- a/pkg/sentry/fs/fsutil/file.go
+++ b/pkg/sentry/fs/fsutil/file.go
@@ -16,13 +16,13 @@ package fsutil
import (
"io"
+ "math"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -211,6 +211,9 @@ func (FileNoMMap) ConfigureMMap(context.Context, *fs.File, *memmap.MMapOpts) err
// GenericConfigureMMap implements fs.FileOperations.ConfigureMMap for most
// filesystems that support memory mapping.
func GenericConfigureMMap(file *fs.File, m memmap.Mappable, opts *memmap.MMapOpts) error {
+ if opts.Offset+opts.Length > math.MaxInt64 {
+ return linuxerr.EOVERFLOW
+ }
opts.Mappable = m
opts.MappingIdentity = file
file.IncRef()
@@ -232,12 +235,12 @@ type FileNoSplice struct{}
// WriteTo implements fs.FileOperations.WriteTo.
func (FileNoSplice) WriteTo(context.Context, *fs.File, io.Writer, int64, bool) (int64, error) {
- return 0, syserror.ENOSYS
+ return 0, linuxerr.ENOSYS
}
// ReadFrom implements fs.FileOperations.ReadFrom.
func (FileNoSplice) ReadFrom(context.Context, *fs.File, io.Reader, int64) (int64, error) {
- return 0, syserror.ENOSYS
+ return 0, linuxerr.ENOSYS
}
// DirFileOperations implements most of fs.FileOperations for directories,
@@ -255,12 +258,12 @@ type DirFileOperations struct {
// Read implements fs.FileOperations.Read
func (*DirFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
// Write implements fs.FileOperations.Write.
func (*DirFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
// StaticDirFileOperations implements fs.FileOperations for directories with
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index 23528bf25..37ddb1a3c 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -93,7 +93,8 @@ func NewHostFileMapper() *HostFileMapper {
func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) {
f.refsMu.Lock()
defer f.refsMu.Unlock()
- for chunkStart := mr.Start &^ chunkMask; chunkStart < mr.End; chunkStart += chunkSize {
+ chunkStart := mr.Start &^ chunkMask
+ for {
refs := f.refs[chunkStart]
pgs := pagesInChunk(mr, chunkStart)
if refs+pgs < refs {
@@ -101,6 +102,10 @@ func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) {
panic(fmt.Sprintf("HostFileMapper.IncRefOn(%v): adding %d page references to chunk %#x, which has %d page references", mr, pgs, chunkStart, refs))
}
f.refs[chunkStart] = refs + pgs
+ chunkStart += chunkSize
+ if chunkStart >= mr.End || chunkStart == 0 {
+ break
+ }
}
}
@@ -112,7 +117,8 @@ func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) {
func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) {
f.refsMu.Lock()
defer f.refsMu.Unlock()
- for chunkStart := mr.Start &^ chunkMask; chunkStart < mr.End; chunkStart += chunkSize {
+ chunkStart := mr.Start &^ chunkMask
+ for {
refs := f.refs[chunkStart]
pgs := pagesInChunk(mr, chunkStart)
switch {
@@ -128,6 +134,10 @@ func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) {
case refs < pgs:
panic(fmt.Sprintf("HostFileMapper.DecRefOn(%v): removing %d page references from chunk %#x, which has %d page references", mr, pgs, chunkStart, refs))
}
+ chunkStart += chunkSize
+ if chunkStart >= mr.End || chunkStart == 0 {
+ break
+ }
}
}
@@ -161,7 +171,8 @@ func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int,
if write {
prot |= unix.PROT_WRITE
}
- for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize {
+ chunkStart := fr.Start &^ chunkMask
+ for {
m, ok := f.mappings[chunkStart]
if !ok {
addr, _, errno := unix.Syscall6(
@@ -201,6 +212,10 @@ func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int,
endOff = fr.End - chunkStart
}
fn(f.unsafeBlockFromChunkMapping(m.addr).TakeFirst64(endOff).DropFirst64(startOff))
+ chunkStart += chunkSize
+ if chunkStart >= fr.End || chunkStart == 0 {
+ break
+ }
}
return nil
}
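
The loop rewrite here is an overflow fix: with the old `for chunkStart < end; chunkStart += chunkSize` form, a range ending near the top of the uint64 space wraps chunkStart to 0, which is still less than end, so the walk never terminates. A runnable sketch of the corrected shape, assuming the 16 MiB chunk size this file uses:

package main

import "fmt"

const (
	chunkShift = 24
	chunkSize  = uint64(1) << chunkShift // 16 MiB, assumed to match this file
	chunkMask  = chunkSize - 1
)

// forEachChunk visits every chunk overlapping [start, end), end > start.
// Advancing chunkStart after the body and breaking on wrap-around (== 0)
// is what keeps a range ending near ^uint64(0) from looping forever.
func forEachChunk(start, end uint64, fn func(chunkStart uint64)) {
	chunkStart := start &^ chunkMask
	for {
		fn(chunkStart)
		chunkStart += chunkSize
		if chunkStart >= end || chunkStart == 0 {
			return
		}
	}
}

func main() {
	// The last chunk of the address space: the old loop form would wrap
	// chunkStart to 0, stay below end, and spin forever.
	n := 0
	forEachChunk(^uint64(0)-chunkMask, ^uint64(0), func(uint64) { n++ })
	fmt.Println(n) // 1
}
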
diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go
index 7c2de04c1..06a994193 100644
--- a/pkg/sentry/fs/fsutil/inode.go
+++ b/pkg/sentry/fs/fsutil/inode.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -167,7 +166,7 @@ func (i *InodeSimpleAttributes) DropLink() {
// StatFS implements fs.InodeOperations.StatFS.
func (i *InodeSimpleAttributes) StatFS(context.Context) (fs.Info, error) {
if i.fsType == 0 {
- return fs.Info{}, syserror.ENOSYS
+ return fs.Info{}, linuxerr.ENOSYS
}
return fs.Info{Type: i.fsType}, nil
}
@@ -294,7 +293,7 @@ type InodeNoStatFS struct{}
// StatFS implements fs.InodeOperations.StatFS.
func (InodeNoStatFS) StatFS(context.Context) (fs.Info, error) {
- return fs.Info{}, syserror.ENOSYS
+ return fs.Info{}, linuxerr.ENOSYS
}
// InodeStaticFileGetter implements GetFile for a file with static contents.
@@ -401,7 +400,7 @@ type InodeIsDirTruncate struct{}
// Truncate implements fs.InodeOperations.Truncate.
func (InodeIsDirTruncate) Truncate(context.Context, *fs.Inode, int64) error {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
// InodeNoopTruncate implements fs.InodeOperations.Truncate as a noop.
@@ -425,7 +424,7 @@ type InodeNotOpenable struct{}
// GetFile implements fs.InodeOperations.GetFile.
func (InodeNotOpenable) GetFile(context.Context, *fs.Dirent, fs.FileFlags) (*fs.File, error) {
- return nil, syserror.EIO
+ return nil, linuxerr.EIO
}
// InodeNotVirtual can be used by Inodes that are not virtual.
@@ -529,5 +528,5 @@ type InodeIsDirAllocate struct{}
// Allocate implements fs.InodeOperations.Allocate.
func (InodeIsDirAllocate) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index c08301d19..ee2f287d9 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -26,6 +26,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors",
"//pkg/errors/linuxerr",
"//pkg/fd",
"//pkg/hostarch",
@@ -48,7 +49,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sync",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/unet",
"//pkg/usermem",
"//pkg/waiter",
@@ -63,10 +63,10 @@ go_test(
library = ":gofer",
deps = [
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/p9",
"//pkg/p9/p9test",
"//pkg/sentry/contexttest",
"//pkg/sentry/fs",
- "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go
index 73d80d9b5..62a517cd7 100644
--- a/pkg/sentry/fs/gofer/file.go
+++ b/pkg/sentry/fs/gofer/file.go
@@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/p9"
@@ -28,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fsmetric"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -226,7 +226,7 @@ func (f *fileOperations) maybeSync(ctx context.Context, file *fs.File, offset, n
func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
if fs.IsDir(file.Dirent.Inode.StableAttr) {
// Not all remote file systems enforce this so this client does.
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
var (
@@ -294,7 +294,7 @@ func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IO
if fs.IsDir(file.Dirent.Inode.StableAttr) {
// Not all remote file systems enforce this so this client does.
f.incrementReadCounters(start)
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
if f.inodeOperations.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) {
diff --git a/pkg/sentry/fs/gofer/gofer_test.go b/pkg/sentry/fs/gofer/gofer_test.go
index 546ee7d04..4924debeb 100644
--- a/pkg/sentry/fs/gofer/gofer_test.go
+++ b/pkg/sentry/fs/gofer/gofer_test.go
@@ -19,8 +19,8 @@ import (
"testing"
"time"
- "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/p9/p9test"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
@@ -97,7 +97,7 @@ func TestLookup(t *testing.T) {
},
{
name: "mock Walk fails (function fails)",
- want: unix.ENOENT,
+ want: linuxerr.ENOENT,
},
}
@@ -123,7 +123,7 @@ func TestLookup(t *testing.T) {
var newInodeOperations fs.InodeOperations
if dirent != nil {
if dirent.IsNegative() {
- err = unix.ENOENT
+ err = linuxerr.ENOENT
} else {
newInodeOperations = dirent.Inode.InodeOperations
}
@@ -131,9 +131,11 @@ func TestLookup(t *testing.T) {
// Check return values.
if err != test.want {
+ t.Logf("err: %v %T", err, err)
t.Errorf("Lookup got err %v, want %v", err, test.want)
}
if err == nil && newInodeOperations == nil {
+ t.Logf("err: %v %T", err, err)
t.Errorf("Lookup got non-nil err and non-nil node, wanted at least one non-nil")
}
})
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 9ff64a8b6..c3856094f 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ gErr "gvisor.dev/gvisor/pkg/errors"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
@@ -32,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/host"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// inodeOperations implements fs.InodeOperations.
@@ -719,12 +719,12 @@ func (i *inodeOperations) configureMMap(file *fs.File, opts *memmap.MMapOpts) er
}
func init() {
- syserror.AddErrorUnwrapper(func(err error) (unix.Errno, bool) {
+ linuxerr.AddErrorUnwrapper(func(err error) (*gErr.Error, bool) {
if _, ok := err.(p9.ErrSocket); ok {
// Treat as an I/O error.
- return unix.EIO, true
+ return linuxerr.EIO, true
}
- return 0, false
+ return nil, false
})
}
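
The init hook now registers an unwrapper with linuxerr instead of syserror, returning a structured error rather than a raw errno. A self-contained sketch of an unwrapper registry in that style; lerr, addErrorUnwrapper, and toError are local stand-ins, not the real linuxerr API:

package main

import (
	"errors"
	"fmt"
)

// lerr stands in for the *errors.Error type linuxerr works with.
type lerr struct{ name string }

func (e *lerr) Error() string { return e.name }

var errIO = &lerr{"EIO"} // stand-in for linuxerr.EIO

// errSocket models p9.ErrSocket: a transport failure wrapping a cause.
type errSocket struct{ cause error }

func (e errSocket) Error() string { return "socket error: " + e.cause.Error() }

// Unwrappers run in order until one claims the error.
var unwrappers []func(error) (*lerr, bool)

func addErrorUnwrapper(f func(error) (*lerr, bool)) { unwrappers = append(unwrappers, f) }

func toError(err error) error {
	for _, f := range unwrappers {
		if le, ok := f(err); ok {
			return le
		}
	}
	return err
}

func main() {
	addErrorUnwrapper(func(err error) (*lerr, bool) {
		if _, ok := err.(errSocket); ok {
			return errIO, true // treat transport failures as I/O errors
		}
		return nil, false
	})
	fmt.Println(toError(errSocket{errors.New("EOF")})) // EIO
}
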
diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go
index 88d83060c..2f8769f1e 100644
--- a/pkg/sentry/fs/gofer/path.go
+++ b/pkg/sentry/fs/gofer/path.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.dev/gvisor/pkg/syserror"
)
// maxFilenameLen is the maximum length of a filename. This is dictated by 9P's
@@ -60,7 +59,7 @@ func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string
if cp.cacheNegativeDirents() {
return fs.NewNegativeDirent(name), nil
}
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
i.readdirMu.Unlock()
}
@@ -74,7 +73,7 @@ func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string
// is created over it.
return fs.NewNegativeDirent(name), nil
}
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
return nil, err
}
@@ -169,7 +168,7 @@ func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string
hostFile.Close()
}
unopened.close(ctx)
- return nil, syserror.EIO
+ return nil, linuxerr.EIO
}
qid := qids[0]
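
Lookup distinguishes two ways of reporting an absent name: a cacheable negative dirent when the cache policy allows it, or a plain ENOENT. A small sketch of that decision, with toy dirent and errNoEnt types standing in for the gVisor ones:

package main

import (
	"errors"
	"fmt"
)

var errNoEnt = errors.New("ENOENT") // stand-in for linuxerr.ENOENT

// dirent is a toy result type: a "negative" dirent records that a name is
// known not to exist, so later lookups can skip the remote round trip.
type dirent struct{ negative bool }

// lookup mirrors the shape of gofer Lookup above: when the name is absent,
// hand back either a cacheable negative dirent or plain ENOENT, depending
// on the cache policy.
func lookup(exists, cacheNegativeDirents bool) (*dirent, error) {
	if exists {
		return &dirent{}, nil
	}
	if cacheNegativeDirents {
		return &dirent{negative: true}, nil
	}
	return nil, errNoEnt
}

func main() {
	d, err := lookup(false, true)
	fmt.Println(d.negative, err) // true <nil>
	_, err = lookup(false, false)
	fmt.Println(err) // ENOENT
}
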
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 24fc6305c..921612e9c 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -52,7 +52,6 @@ go_library(
"//pkg/sentry/uniqueid",
"//pkg/sync",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/tcpip",
"//pkg/unet",
"//pkg/usermem",
diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go
index 77c08a7ce..1d0d95634 100644
--- a/pkg/sentry/fs/host/file.go
+++ b/pkg/sentry/fs/host/file.go
@@ -28,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -201,7 +200,7 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I
writer := fd.NewReadWriter(f.iops.fileState.FD())
n, err := src.CopyInTo(ctx, safemem.FromIOWriter{writer})
if isBlockError(err) {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
return n, err
}
@@ -232,7 +231,7 @@ func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IO
if n != 0 {
err = nil
} else {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
}
return n, err
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 5f6af2067..92d58e3e9 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -220,7 +219,7 @@ func (i *inodeOperations) Release(context.Context) {
// Lookup implements fs.InodeOperations.Lookup.
func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
// Create implements fs.InodeOperations.Create.
@@ -400,7 +399,7 @@ func (i *inodeOperations) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error
// StatFS implements fs.InodeOperations.StatFS.
func (i *inodeOperations) StatFS(context.Context) (fs.Info, error) {
- return fs.Info{}, syserror.ENOSYS
+ return fs.Info{}, linuxerr.ENOSYS
}
// AddLink implements fs.InodeOperations.AddLink.
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 6f38b25c3..4e561c5ed 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -327,7 +326,7 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e
// If the signal is SIGTTIN, then we are attempting to read
// from the TTY. Don't send the signal and return EIO.
if sig == linux.SIGTTIN {
- return syserror.EIO
+ return linuxerr.EIO
}
// Otherwise, we are writing or changing terminal state. This is allowed.
@@ -336,7 +335,7 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e
// If the process group is an orphan, return EIO.
if pg.IsOrphan() {
- return syserror.EIO
+ return linuxerr.EIO
}
// Otherwise, send the signal to the process group and return ERESTARTSYS.
@@ -349,7 +348,7 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e
//
// Linux ignores the result of kill_pgrp().
_ = pg.SendSignal(kernel.SignalInfoPriv(sig))
- return syserror.ERESTARTSYS
+ return linuxerr.ERESTARTSYS
}
// LINT.ThenChange(../../fsimpl/host/tty.go)
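
checkChange encodes the background-process-group rules: reads fail with EIO rather than signaling when SIGTTIN is ignored, orphaned groups also get EIO, and everyone else is signaled and told to restart the syscall. A simplified, runnable sketch of that decision table, with local error values standing in for linuxerr.EIO and linuxerr.ERESTARTSYS:

package main

import (
	"errors"
	"fmt"
)

var (
	errIO         = errors.New("EIO")
	errRestartSys = errors.New("ERESTARTSYS")
)

const (
	SIGTTIN = 21 // read from TTY by background process group
	SIGTTOU = 22 // write to TTY by background process group
)

// checkChange sketches the logic above for a task in a background process
// group: if the stop signal is ignored, reads fail with EIO while writes
// proceed; orphaned groups get EIO; otherwise the group is signaled and
// the syscall restarts.
func checkChange(sig int, sigIgnored, orphan bool) error {
	if sigIgnored {
		if sig == SIGTTIN {
			return errIO // don't send the signal, just fail the read
		}
		return nil // writing/ioctl is allowed when SIGTTOU is ignored
	}
	if orphan {
		return errIO // no one could ever resume an orphaned group
	}
	// Send sig to the whole process group (elided), then restart.
	return errRestartSys
}

func main() {
	fmt.Println(checkChange(SIGTTIN, true, false))  // EIO
	fmt.Println(checkChange(SIGTTOU, false, false)) // ERESTARTSYS
}
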
diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go
index e7db79189..f2a33cc14 100644
--- a/pkg/sentry/fs/host/util.go
+++ b/pkg/sentry/fs/host/util.go
@@ -96,7 +96,7 @@ type dirInfo struct {
// LINT.IfChange
// isBlockError unwraps os errors and checks if they are caused by EAGAIN or
-// EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock.
+// EWOULDBLOCK. This is so they can be transformed into linuxerr.ErrWouldBlock.
func isBlockError(err error) bool {
if linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err) {
return true
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
index ec204e5cf..2c6b9e9db 100644
--- a/pkg/sentry/fs/inode.go
+++ b/pkg/sentry/fs/inode.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Inode is a file system object that can be simultaneously referenced by different
@@ -357,7 +356,7 @@ func (i *Inode) SetTimestamps(ctx context.Context, d *Dirent, ts TimeSpec) error
// Truncate calls i.InodeOperations.Truncate with i as the Inode.
func (i *Inode) Truncate(ctx context.Context, d *Dirent, size int64) error {
if IsDir(i.StableAttr) {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
if i.overlay != nil {
diff --git a/pkg/sentry/fs/inode_operations.go b/pkg/sentry/fs/inode_operations.go
index 98e9fb2b1..0f8022906 100644
--- a/pkg/sentry/fs/inode_operations.go
+++ b/pkg/sentry/fs/inode_operations.go
@@ -66,7 +66,7 @@ type InodeOperations interface {
//
// * A nil Dirent and a non-nil error. If the reason that Lookup failed
// was because the name does not exist under Inode, then must return
- // syserror.ENOENT.
+ // linuxerr.ENOENT.
//
// * If name does not exist under dir and the file system wishes this
// fact to be cached, a non-nil Dirent containing a nil Inode and a
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
index c47b9ce58..21ad7fa69 100644
--- a/pkg/sentry/fs/inode_overlay.go
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.dev/gvisor/pkg/syserror"
)
func overlayHasWhiteout(ctx context.Context, parent *Inode, name string) bool {
@@ -103,7 +102,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
// Upper fs is not OK with a negative Dirent
// being cached in the Dirent tree, so don't
// return one.
- return nil, false, syserror.ENOENT
+ return nil, false, linuxerr.ENOENT
}
entry, err := newOverlayEntry(ctx, upperInode, nil, false)
if err != nil {
@@ -165,7 +164,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
if negativeUpperChild {
return NewNegativeDirent(name), false, nil
}
- return nil, false, syserror.ENOENT
+ return nil, false, linuxerr.ENOENT
}
// Did we find a lower Inode? Remember this because we may decide we don't
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index ee28b0f99..51cd6cd37 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -141,7 +140,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i
if i.events.Empty() {
// Nothing to read yet, tell caller to block.
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
var writeLen int64
@@ -179,7 +178,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i
// WriteTo implements FileOperations.WriteTo.
func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, error) {
- return 0, syserror.ENOSYS
+ return 0, linuxerr.ENOSYS
}
// Fsync implements FileOperations.Fsync.
@@ -189,7 +188,7 @@ func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error {
// ReadFrom implements FileOperations.ReadFrom.
func (*Inotify) ReadFrom(context.Context, *File, io.Reader, int64) (int64, error) {
- return 0, syserror.ENOSYS
+ return 0, linuxerr.ENOSYS
}
// Flush implements FileOperations.Flush.
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index e6d74b949..bc75ae505 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -50,7 +50,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/tcpip/header",
"//pkg/tcpip/network/ipv4",
"//pkg/usermem",
diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go
index 379429ab2..75dc5d204 100644
--- a/pkg/sentry/fs/proc/exec_args.go
+++ b/pkg/sentry/fs/proc/exec_args.go
@@ -107,7 +107,7 @@ func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen
return 0, linuxerr.EINVAL
}
- m, err := getTaskMM(f.t)
+ m, err := getTaskMMIncRef(f.t)
if err != nil {
return 0, err
}
diff --git a/pkg/sentry/fs/proc/fds.go b/pkg/sentry/fs/proc/fds.go
index e90da225a..e68bb46c0 100644
--- a/pkg/sentry/fs/proc/fds.go
+++ b/pkg/sentry/fs/proc/fds.go
@@ -20,12 +20,12 @@ import (
"strconv"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
// LINT.IfChange
@@ -37,7 +37,7 @@ func walkDescriptors(t *kernel.Task, p string, toInode func(*fs.File, kernel.FDF
n, err := strconv.ParseUint(p, 10, 64)
if err != nil {
// Not found.
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
var file *fs.File
@@ -48,7 +48,7 @@ func walkDescriptors(t *kernel.Task, p string, toInode func(*fs.File, kernel.FDF
}
})
if file == nil {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
return toInode(file, fdFlags), nil
}
diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go
index 546b57287..b9629c598 100644
--- a/pkg/sentry/fs/proc/proc.go
+++ b/pkg/sentry/fs/proc/proc.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package proc implements a partial in-memory file system for profs.
+// Package proc implements a partial in-memory file system for procfs.
package proc
import (
@@ -28,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
// LINT.IfChange
@@ -125,7 +124,7 @@ func (s *self) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
if t := kernel.TaskFromContext(ctx); t != nil {
tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
if tgid == 0 {
- return "", syserror.ENOENT
+ return "", linuxerr.ENOENT
}
return strconv.FormatUint(uint64(tgid), 10), nil
}
@@ -149,7 +148,7 @@ func (s *threadSelf) Readlink(ctx context.Context, inode *fs.Inode) (string, err
tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
tid := s.pidns.IDOfTask(t)
if tid == 0 || tgid == 0 {
- return "", syserror.ENOENT
+ return "", linuxerr.ENOENT
}
return fmt.Sprintf("%d/task/%d", tgid, tid), nil
}
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
index 085aa6d61..443b9a94c 100644
--- a/pkg/sentry/fs/proc/sys.go
+++ b/pkg/sentry/fs/proc/sys.go
@@ -109,6 +109,9 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode
"shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))),
"shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))),
"shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))),
+ "msgmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMNI, 10))),
+ "msgmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMAX, 10))),
+ "msgmnb": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMNB, 10))),
}
d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555))
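
The three new entries expose the SysV message-queue limits the same way shmall/shmmax/shmmni already are: as static files holding a decimal-formatted constant. A tiny sketch of the formatting, using the stock Linux defaults for MSGMNI/MSGMAX/MSGMNB as assumed stand-ins for the pkg/abi/linux constants:

package main

import (
	"fmt"
	"strconv"
)

// Assumed stock Linux defaults (include/uapi/linux/msg.h); gVisor mirrors
// them as linux.MSGMNI, linux.MSGMAX, and linux.MSGMNB.
const (
	MSGMNI = 32000 // max number of message queue identifiers
	MSGMAX = 8192  // max size of a single message, in bytes
	MSGMNB = 16384 // default max size of a queue, in bytes
)

func main() {
	for _, e := range []struct {
		name string
		val  uint64
	}{{"msgmni", MSGMNI}, {"msgmax", MSGMAX}, {"msgmnb", MSGMNB}} {
		// Each static inode's contents are just the constant in decimal.
		fmt.Printf("/proc/sys/kernel/%s -> %s\n", e.name, strconv.FormatUint(e.val, 10))
	}
}
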
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index edd62b857..03f2a882d 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -35,17 +35,30 @@ import (
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
// LINT.IfChange
-// getTaskMM returns t's MemoryManager. If getTaskMM succeeds, the MemoryManager's
-// users count is incremented, and must be decremented by the caller when it is
-// no longer in use.
-func getTaskMM(t *kernel.Task) (*mm.MemoryManager, error) {
+// getTaskMM gets the kernel task's MemoryManager. No additional reference is
+// taken on mm here. This is safe because MemoryManager.destroy is required to
+// leave the MemoryManager in a state where it's still usable as a
+// DynamicBytesSource.
+func getTaskMM(t *kernel.Task) *mm.MemoryManager {
+ var tmm *mm.MemoryManager
+ t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ tmm = mm
+ }
+ })
+ return tmm
+}
+
+// getTaskMMIncRef returns t's MemoryManager. If getTaskMMIncRef succeeds, the
+// MemoryManager's users count is incremented, and must be decremented by the
+// caller when it is no longer in use.
+func getTaskMMIncRef(t *kernel.Task) (*mm.MemoryManager, error) {
if t.ExitState() == kernel.TaskExitDead {
return nil, linuxerr.ESRCH
}
@@ -182,7 +195,7 @@ func (f *subtasksFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dentry
tasks := f.t.ThreadGroup().MemberIDs(f.pidns)
if len(tasks) == 0 {
- return offset, syserror.ENOENT
+ return offset, linuxerr.ENOENT
}
if offset == 0 {
@@ -234,15 +247,15 @@ var _ fs.FileOperations = (*subtasksFile)(nil)
func (s *subtasks) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs.Dirent, error) {
tid, err := strconv.ParseUint(p, 10, 32)
if err != nil {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
task := s.p.pidns.TaskWithID(kernel.ThreadID(tid))
if task == nil {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
if task.ThreadGroup() != s.t.ThreadGroup() {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
td := s.p.newTaskDir(ctx, task, dir.MountSource, false)
@@ -270,21 +283,18 @@ func (e *exe) executable() (file fsbridge.File, err error) {
if err := checkTaskState(e.t); err != nil {
return nil, err
}
- e.t.WithMuLocked(func(t *kernel.Task) {
- mm := t.MemoryManager()
- if mm == nil {
- err = linuxerr.EACCES
- return
- }
+ mm := getTaskMM(e.t)
+ if mm == nil {
+ return nil, linuxerr.EACCES
+ }
- // The MemoryManager may be destroyed, in which case
- // MemoryManager.destroy will simply set the executable to nil
- // (with locks held).
- file = mm.Executable()
- if file == nil {
- err = linuxerr.ESRCH
- }
- })
+ // The MemoryManager may be destroyed, in which case
+ // MemoryManager.destroy will simply set the executable to nil
+ // (with locks held).
+ file = mm.Executable()
+ if file == nil {
+ err = linuxerr.ESRCH
+ }
return
}
@@ -464,7 +474,7 @@ func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen
if dst.NumBytes() == 0 {
return 0, nil
}
- mm, err := getTaskMM(m.t)
+ mm, err := getTaskMMIncRef(m.t)
if err != nil {
return 0, nil
}
@@ -479,7 +489,7 @@ func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen
return int64(n), nil
}
if readErr != nil {
- return 0, syserror.EIO
+ return 0, linuxerr.EIO
}
return 0, nil
}
@@ -495,22 +505,9 @@ func newMaps(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inod
return newProcInode(ctx, seqfile.NewSeqFile(ctx, &mapsData{t}), msrc, fs.SpecialFile, t)
}
-func (md *mapsData) mm() *mm.MemoryManager {
- var tmm *mm.MemoryManager
- md.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- // No additional reference is taken on mm here. This is safe
- // because MemoryManager.destroy is required to leave the
- // MemoryManager in a state where it's still usable as a SeqSource.
- tmm = mm
- }
- })
- return tmm
-}
-
// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
func (md *mapsData) NeedsUpdate(generation int64) bool {
- if mm := md.mm(); mm != nil {
+ if mm := getTaskMM(md.t); mm != nil {
return mm.NeedsUpdate(generation)
}
return true
@@ -518,7 +515,7 @@ func (md *mapsData) NeedsUpdate(generation int64) bool {
// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
func (md *mapsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
- if mm := md.mm(); mm != nil {
+ if mm := getTaskMM(md.t); mm != nil {
return mm.ReadMapsSeqFileData(ctx, h)
}
return []seqfile.SeqData{}, 0
@@ -535,22 +532,9 @@ func newSmaps(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Ino
return newProcInode(ctx, seqfile.NewSeqFile(ctx, &smapsData{t}), msrc, fs.SpecialFile, t)
}
-func (sd *smapsData) mm() *mm.MemoryManager {
- var tmm *mm.MemoryManager
- sd.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- // No additional reference is taken on mm here. This is safe
- // because MemoryManager.destroy is required to leave the
- // MemoryManager in a state where it's still usable as a SeqSource.
- tmm = mm
- }
- })
- return tmm
-}
-
// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
func (sd *smapsData) NeedsUpdate(generation int64) bool {
- if mm := sd.mm(); mm != nil {
+ if mm := getTaskMM(sd.t); mm != nil {
return mm.NeedsUpdate(generation)
}
return true
@@ -558,7 +542,7 @@ func (sd *smapsData) NeedsUpdate(generation int64) bool {
// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
func (sd *smapsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
- if mm := sd.mm(); mm != nil {
+ if mm := getTaskMM(sd.t); mm != nil {
return mm.ReadSmapsSeqFileData(ctx, h)
}
return []seqfile.SeqData{}, 0
@@ -628,12 +612,10 @@ func (s *taskStatData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
fmt.Fprintf(&buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime())))
var vss, rss uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- }
- })
+ if mm := getTaskMM(s.t); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
fmt.Fprintf(&buf, "%d %d ", vss, rss/hostarch.PageSize)
// rsslim.
@@ -678,12 +660,10 @@ func (s *statmData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([
}
var vss, rss uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- }
- })
+ if mm := getTaskMM(s.t); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
var buf bytes.Buffer
fmt.Fprintf(&buf, "%d %d 0 0 0 0 0\n", vss/hostarch.PageSize, rss/hostarch.PageSize)
@@ -735,12 +715,13 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) (
if fdTable := t.FDTable(); fdTable != nil {
fds = fdTable.CurrentMaxFDs()
}
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- data = mm.VirtualDataSize()
- }
})
+
+ if mm := getTaskMM(s.t); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ data = mm.VirtualDataSize()
+ }
fmt.Fprintf(&buf, "FDSize:\t%d\n", fds)
fmt.Fprintf(&buf, "VmSize:\t%d kB\n", vss>>10)
fmt.Fprintf(&buf, "VmRSS:\t%d kB\n", rss>>10)
@@ -926,7 +907,7 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc
return 0, linuxerr.EINVAL
}
- m, err := getTaskMM(f.t)
+ m, err := getTaskMMIncRef(f.t)
if err != nil {
return 0, err
}
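
The helper split above separates two lifetimes: getTaskMM borrows the MemoryManager with no reference (safe for the seqfile sources, since destroy leaves it readable), while getTaskMMIncRef bumps the users count and obliges the caller to drop it. A toy refcount sketch of the distinction; the mm type and error text here are illustrative only:

package main

import "fmt"

// mm is a toy MemoryManager with a users count, in the spirit of
// pkg/sentry/mm; the real type and errors differ.
type mm struct{ users int }

// IncUsers refuses once the count has dropped to zero, i.e. after destroy.
func (m *mm) IncUsers() bool {
	if m.users == 0 {
		return false
	}
	m.users++
	return true
}

func (m *mm) DecUsers() { m.users-- }

// getTaskMM borrows the MemoryManager without a reference. Callers like
// the seqfile sources tolerate a destroyed-but-readable mm, so no count
// is taken; nil means the task has no mm at all.
func getTaskMM(m *mm) *mm { return m }

// getTaskMMIncRef takes a user reference; the caller must DecUsers when
// done, and destroyed managers are rejected outright.
func getTaskMMIncRef(m *mm) (*mm, error) {
	if m == nil || !m.IncUsers() {
		return nil, fmt.Errorf("mm unusable: task dead or mm destroyed")
	}
	return m, nil
}

func main() {
	m := &mm{users: 1}
	if borrowed := getTaskMM(m); borrowed != nil {
		fmt.Println("borrowed for a seqfile read, no reference held")
	}
	if held, err := getTaskMMIncRef(m); err == nil {
		fmt.Println("holding a user reference for I/O")
		held.DecUsers()
	}
}
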
diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD
index b46567cf8..bfff010c5 100644
--- a/pkg/sentry/fs/ramfs/BUILD
+++ b/pkg/sentry/fs/ramfs/BUILD
@@ -21,7 +21,6 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/socket/unix/transport",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go
index 33023af77..b1fadee7a 100644
--- a/pkg/sentry/fs/ramfs/dir.go
+++ b/pkg/sentry/fs/ramfs/dir.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// CreateOps represents operations to create different file types.
@@ -284,9 +283,9 @@ func (d *Dir) walkLocked(ctx context.Context, p string) (*fs.Inode, error) {
return inode, nil
}
- // fs.InodeOperations.Lookup returns syserror.ENOENT if p
+ // fs.InodeOperations.Lookup returns linuxerr.ENOENT if p
// does not exist.
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
// createInodeOperationsCommon creates a new child node at this dir by calling
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
index fff4befb2..266140f6f 100644
--- a/pkg/sentry/fs/splice.go
+++ b/pkg/sentry/fs/splice.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Splice moves data to this file, directly from another.
@@ -55,26 +54,26 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
case dst.UniqueID < src.UniqueID:
// Acquire dst first.
if !dst.mu.Lock(ctx) {
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
if !src.mu.Lock(ctx) {
dst.mu.Unlock()
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
case dst.UniqueID > src.UniqueID:
// Acquire src first.
if !src.mu.Lock(ctx) {
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
if !dst.mu.Lock(ctx) {
src.mu.Unlock()
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
case dst.UniqueID == src.UniqueID:
// Acquire only one lock; it's the same file. This is a
// bit of an edge case, but presumably it's possible.
if !dst.mu.Lock(ctx) {
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
srcLock = false // Only need one unlock.
}
@@ -84,13 +83,13 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
case dstLock:
// Acquire only dst.
if !dst.mu.Lock(ctx) {
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
opts.DstStart = dst.offset // Safe: locked.
case srcLock:
// Acquire only src.
if !src.mu.Lock(ctx) {
- return 0, syserror.ErrInterrupted
+ return 0, linuxerr.ErrInterrupted
}
opts.SrcStart = src.offset // Safe: locked.
}
@@ -108,7 +107,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
limit, ok := dst.checkLimit(ctx, opts.DstStart)
switch {
case ok && limit == 0:
- err = syserror.ErrExceedsFileSizeLimit
+ err = linuxerr.ErrExceedsFileSizeLimit
case ok && limit < opts.Length:
opts.Length = limit // Cap the write.
}
diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD
index 0148b33cf..e61115932 100644
--- a/pkg/sentry/fs/timerfd/BUILD
+++ b/pkg/sentry/fs/timerfd/BUILD
@@ -14,7 +14,6 @@ go_library(
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/kernel/time",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
index 093a14c1f..1c8518d71 100644
--- a/pkg/sentry/fs/timerfd/timerfd.go
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -134,7 +133,7 @@ func (t *TimerOperations) Read(ctx context.Context, file *fs.File, dst usermem.I
}
return sizeofUint64, nil
}
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
// Write implements fs.FileOperations.Write.
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index 5933cb67b..9e9dc06f3 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -31,7 +31,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/unimpl",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go
index 3242dcb6a..5716e2ee9 100644
--- a/pkg/sentry/fs/tty/dir.go
+++ b/pkg/sentry/fs/tty/dir.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -155,12 +154,12 @@ func (d *dirInodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name str
n, err := strconv.ParseUint(name, 10, 32)
if err != nil {
// Not found.
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
s, ok := d.replicas[uint32(n)]
if !ok {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
s.IncRef()
@@ -235,7 +234,7 @@ func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, e
n := d.next
if n == math.MaxUint32 {
- return nil, syserror.ENOMEM
+ return nil, linuxerr.ENOMEM
}
if _, ok := d.replicas[n]; ok {
@@ -335,10 +334,10 @@ func (df *dirFileOperations) Readdir(ctx context.Context, file *fs.File, seriali
// Read implements FileOperations.Read
func (df *dirFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
// Write implements FileOperations.Write.
func (df *dirFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go
index 3ba02c218..f9fca6d8e 100644
--- a/pkg/sentry/fs/tty/line_discipline.go
+++ b/pkg/sentry/fs/tty/line_discipline.go
@@ -20,10 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -193,7 +193,7 @@ func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSeque
}
return n, nil
}
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
@@ -207,7 +207,7 @@ func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequ
l.replicaWaiter.Notify(waiter.ReadableEvents)
return n, nil
}
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
func (l *lineDiscipline) outputQueueReadSize(t *kernel.Task, args arch.SyscallArguments) error {
@@ -228,7 +228,7 @@ func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequ
}
return n, nil
}
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
func (l *lineDiscipline) outputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
@@ -242,7 +242,7 @@ func (l *lineDiscipline) outputQueueWrite(ctx context.Context, src usermem.IOSeq
l.masterWaiter.Notify(waiter.ReadableEvents)
return n, nil
}
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
// transformer is a helper interface to make it easier to stateify queue.
diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go
index 11d6c15d0..25d3c887e 100644
--- a/pkg/sentry/fs/tty/queue.go
+++ b/pkg/sentry/fs/tty/queue.go
@@ -17,12 +17,12 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -110,7 +110,7 @@ func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipl
defer q.mu.Unlock()
if !q.readable {
- return 0, false, syserror.ErrWouldBlock
+ return 0, false, linuxerr.ErrWouldBlock
}
if dst.NumBytes() > canonMaxBytes {
@@ -155,7 +155,7 @@ func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscip
room := waitBufMaxBytes - q.waitBufLen
// If out of room, return EAGAIN.
if room == 0 && copyLen > 0 {
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
// Cap the size of the wait buffer.
if copyLen > room {
diff --git a/pkg/sentry/fs/user/BUILD b/pkg/sentry/fs/user/BUILD
index 4acc73ee0..23b5508fd 100644
--- a/pkg/sentry/fs/user/BUILD
+++ b/pkg/sentry/fs/user/BUILD
@@ -19,7 +19,6 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go
index f6eaab2bd..67a9adfd7 100644
--- a/pkg/sentry/fs/user/path.go
+++ b/pkg/sentry/fs/user/path.go
@@ -28,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// ResolveExecutablePath resolves the given executable name given the working
@@ -81,7 +80,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s
root := fs.RootFromContext(ctx)
if root == nil {
// Caller has no root. Don't bother traversing anything.
- return "", syserror.ENOENT
+ return "", linuxerr.ENOENT
}
defer root.DecRef(ctx)
for _, p := range paths {
@@ -117,7 +116,7 @@ func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name s
}
// Couldn't find it.
- return "", syserror.ENOENT
+ return "", linuxerr.ENOENT
}
func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, paths []string, name string) (string, error) {
@@ -156,7 +155,7 @@ func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNam
}
// Couldn't find it.
- return "", syserror.ENOENT
+ return "", linuxerr.ENOENT
}
// getPath returns the PATH as a slice of strings given the environment
diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD
index 4c9c5b344..e5fdcc776 100644
--- a/pkg/sentry/fsimpl/cgroupfs/BUILD
+++ b/pkg/sentry/fsimpl/cgroupfs/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/context",
"//pkg/coverage",
"//pkg/errors/linuxerr",
+ "//pkg/fspath",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
@@ -43,7 +44,6 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go
index 4290ffe0d..71bb0a9c8 100644
--- a/pkg/sentry/fsimpl/cgroupfs/base.go
+++ b/pkg/sentry/fsimpl/cgroupfs/base.go
@@ -88,7 +88,6 @@ type controller interface {
// +stateify savable
type cgroupInode struct {
dir
- fs *filesystem
// ts is the list of tasks in this cgroup. The kernel is responsible for
// removing tasks from this list before they're destroyed, so any tasks on
@@ -102,9 +101,10 @@ var _ kernel.CgroupImpl = (*cgroupInode)(nil)
func (fs *filesystem) newCgroupInode(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
c := &cgroupInode{
- fs: fs,
- ts: make(map[*kernel.Task]struct{}),
+ dir: dir{fs: fs},
+ ts: make(map[*kernel.Task]struct{}),
}
+ c.dir.cgi = c
contents := make(map[string]kernfs.Inode)
contents["cgroup.procs"] = fs.newControllerFile(ctx, creds, &cgroupProcsData{c})
@@ -115,8 +115,7 @@ func (fs *filesystem) newCgroupInode(ctx context.Context, creds *auth.Credential
}
c.dir.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|linux.FileMode(0555))
- c.dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
- c.dir.InitRefs()
+ c.dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true})
c.dir.IncLinks(c.dir.OrderedChildren.Populate(contents))
atomic.AddUint64(&fs.numCgroups, 1)
diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
index 22c8b7fda..edc3b50b9 100644
--- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
+++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
@@ -32,7 +32,8 @@
// controllers associated with them.
//
// Since cgroupfs doesn't allow hardlinks, there is a unique mapping between
-// cgroupfs dentries and inodes.
+// cgroupfs dentries and inodes. Thus, cgroupfs inodes don't need to be ref
+// counted; they exist until they're unlinked (which can happen at most once)
+// or the FS is destroyed.
//
// # Synchronization
//
@@ -48,10 +49,11 @@
// Lock order:
//
// kernel.CgroupRegistry.mu
-// cgroupfs.filesystem.mu
-// kernel.TaskSet.mu
-// kernel.Task.mu
-// cgroupfs.filesystem.tasksMu.
+// kernfs.filesystem.mu
+// kernel.TaskSet.mu
+// kernel.Task.mu
+// cgroupfs.filesystem.tasksMu.
+// cgroupfs.dir.OrderedChildren.mu
package cgroupfs
import (
@@ -63,6 +65,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -108,6 +111,7 @@ type FilesystemType struct{}
// +stateify savable
type InternalData struct {
DefaultControlValues map[string]int64
+ InitialCgroupPath string
}
// filesystem implements vfs.FilesystemImpl and kernel.cgroupFS.
@@ -134,6 +138,11 @@ type filesystem struct {
numCgroups uint64 // Protected by atomic ops.
root *kernfs.Dentry
+ // effectiveRoot is the initial cgroup new tasks are created in. Unless
+ // overwritten by internal mount options, root == effectiveRoot. If
+ // effectiveRoot != root, an extra reference is held on effectiveRoot for
+ // the lifetime of the filesystem.
+ effectiveRoot *kernfs.Dentry
// tasksMu serializes task membership changes across all cgroups within a
// filesystem.
@@ -229,6 +238,9 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fs := vfsfs.Impl().(*filesystem)
ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: mounting new view to hierarchy %v", fs.hierarchyID)
fs.root.IncRef()
+ if fs.effectiveRoot != fs.root {
+ fs.effectiveRoot.IncRef()
+ }
return vfsfs, fs.root.VFSDentry(), nil
}
@@ -245,8 +257,8 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
var defaults map[string]int64
if opts.InternalData != nil {
- ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: default control values: %v", defaults)
defaults = opts.InternalData.(*InternalData).DefaultControlValues
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: default control values: %v", defaults)
}
for _, ty := range wantControllers {
@@ -286,6 +298,14 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
var rootD kernfs.Dentry
rootD.InitRoot(&fs.Filesystem, root)
fs.root = &rootD
+ fs.effectiveRoot = fs.root
+
+ if err := fs.prepareInitialCgroup(ctx, vfsObj, opts); err != nil {
+ ctx.Warningf("cgroupfs.FilesystemType.GetFilesystem: failed to prepare initial cgroup: %v", err)
+ rootD.DecRef(ctx)
+ fs.VFSFilesystem().DecRef(ctx)
+ return nil, nil, err
+ }
// Register controllers. The registry may be modified concurrently, so if we
// get an error, we raced with someone else who registered the same
@@ -303,10 +323,47 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return fs.VFSFilesystem(), rootD.VFSDentry(), nil
}
+// prepareInitialCgroup creates the initial cgroup according to opts. An initial
+// cgroup is optional, and if not specified, this function is a no-op.
+func (fs *filesystem) prepareInitialCgroup(ctx context.Context, vfsObj *vfs.VirtualFilesystem, opts vfs.GetFilesystemOptions) error {
+ if opts.InternalData == nil {
+ return nil
+ }
+ initPathStr := opts.InternalData.(*InternalData).InitialCgroupPath
+ if initPathStr == "" {
+ return nil
+ }
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: initial cgroup path: %v", initPathStr)
+ initPath := fspath.Parse(initPathStr)
+ if !initPath.Absolute || !initPath.HasComponents() {
+ ctx.Warningf("cgroupfs.FilesystemType.GetFilesystem: initial cgroup path invalid: %+v", initPath)
+ return linuxerr.EINVAL
+ }
+
+ // Have initial cgroup target, create the tree.
+ cgDir := fs.root.Inode().(*cgroupInode)
+ for pit := initPath.Begin; pit.Ok(); pit = pit.Next() {
+ cgDirI, err := cgDir.NewDir(ctx, pit.String(), vfs.MkdirOptions{})
+ if err != nil {
+ return err
+ }
+ cgDir = cgDirI.(*cgroupInode)
+ }
+
+ // Walk to target dentry.
+ initDentry, err := fs.root.WalkDentryTree(ctx, vfsObj, initPath)
+ if err != nil {
+ ctx.Warningf("cgroupfs.FilesystemType.GetFilesystem: initial cgroup dentry not found: %v", err)
+ return linuxerr.ENOENT
+ }
+ fs.effectiveRoot = initDentry // Reference from WalkDentryTree transferred here.
+ return nil
+}
+
func (fs *filesystem) rootCgroup() kernel.Cgroup {
return kernel.Cgroup{
- Dentry: fs.root,
- CgroupImpl: fs.root.Inode().(kernel.CgroupImpl),
+ Dentry: fs.effectiveRoot,
+ CgroupImpl: fs.effectiveRoot.Inode().(kernel.CgroupImpl),
}
}
@@ -320,6 +377,10 @@ func (fs *filesystem) Release(ctx context.Context) {
r.Unregister(fs.hierarchyID)
}
+ if fs.root != fs.effectiveRoot {
+ fs.effectiveRoot.DecRef(ctx)
+ }
+
fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
fs.Filesystem.Release(ctx)
}
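
prepareInitialCgroup above parses InitialCgroupPath and creates one cgroup per component before walking to the resulting dentry. The component loop in isolation, using the fspath API as it appears in this diff (illustrative helper, not part of the change):

    import "gvisor.dev/gvisor/pkg/fspath"

    // components returns the components of a parsed path, mirroring the
    // pit := initPath.Begin; pit.Ok(); pit = pit.Next() loop above; for
    // "/a/b/c" it yields ["a" "b" "c"].
    func components(p string) []string {
    	var out []string
    	for pit := fspath.Parse(p).Begin; pit.Ok(); pit = pit.Next() {
    		out = append(out, pit.String())
    	}
    	return out
    }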
@@ -346,15 +407,18 @@ func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error
//
// +stateify savable
type dir struct {
- dirRefs
+ kernfs.InodeNoopRefCount
kernfs.InodeAlwaysValid
kernfs.InodeAttrs
kernfs.InodeNotSymlink
- kernfs.InodeDirectoryNoNewChildren // TODO(b/183137098): Implement mkdir.
+ kernfs.InodeDirectoryNoNewChildren
kernfs.OrderedChildren
implStatFS
locks vfs.FileLocks
+
+ fs *filesystem // Immutable.
+ cgi *cgroupInode // Immutable.
}
// Keep implements kernfs.Inode.Keep.
@@ -378,9 +442,100 @@ func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry
return fd.VFSFileDescription(), nil
}
-// DecRef implements kernfs.Inode.DecRef.
-func (d *dir) DecRef(ctx context.Context) {
- d.dirRefs.DecRef(func() { d.Destroy(ctx) })
+// NewDir implements kernfs.Inode.NewDir.
+func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) {
+ // "Do not accept '\n' to prevent making /proc/<pid>/cgroup unparsable."
+ // -- Linux, kernel/cgroup.c:cgroup_mkdir().
+ if strings.Contains(name, "\n") {
+ return nil, linuxerr.EINVAL
+ }
+ return d.OrderedChildren.Inserter(name, func() kernfs.Inode {
+ d.IncLinks(1)
+ return d.fs.newCgroupInode(ctx, auth.CredentialsFromContext(ctx))
+ })
+}
+
+// Rename implements kernfs.Inode.Rename. Cgroupfs only allows renaming of
+// cgroup directories, and the rename may only change the name within the same
+// parent. See linux, kernel/cgroup.c:cgroup_rename().
+func (d *dir) Rename(ctx context.Context, oldname, newname string, child, dst kernfs.Inode) error {
+ if _, ok := child.(*cgroupInode); !ok {
+ // Not a cgroup directory. Control files are backed by different types.
+ return linuxerr.ENOTDIR
+ }
+
+ dstCGInode, ok := dst.(*cgroupInode)
+ if !ok {
+ // Not a cgroup inode, so definitely can't be *this* inode.
+ return linuxerr.EIO
+ }
+ // Note: We're intentionally comparing addresses, since two different dirs
+ // could plausibly be identical in value, but would occupy different
+ // locations in memory.
+ if d != &dstCGInode.dir {
+ // Destination dir is a different cgroup inode. Cross directory renames
+ // aren't allowed.
+ return linuxerr.EIO
+ }
+
+ // Rename moves oldname to newname within d. Proceed.
+ return d.OrderedChildren.Rename(ctx, oldname, newname, child, dst)
+}
+
+// Unlink implements kernfs.Inode.Unlink. Cgroupfs disallows unlink, as the only
+// files in the filesystem are control files, which can't be deleted.
+func (d *dir) Unlink(ctx context.Context, name string, child kernfs.Inode) error {
+ return linuxerr.EPERM
+}
+
+// hasChildrenLocked returns whether the cgroup dir contains any objects that
+// prevent it from being deleted.
+func (d *dir) hasChildrenLocked() bool {
+ // Subdirs take a link on the parent, so this checks whether there are any
+ // direct child cgroups. Exclude the dir's self link and the link from ".".
+ if d.InodeAttrs.Links()-2 > 0 {
+ return true
+ }
+ return len(d.cgi.ts) > 0
+}
+
+// HasChildren implements kernfs.Inode.HasChildren.
+//
+// The empty check for a cgroupfs directory differs from that for a regular
+// directory, since a cgroupfs directory always contains control files. A
+// cgroupfs directory can be deleted if the cgroup contains no tasks and has
+// no sub-cgroups.
+func (d *dir) HasChildren() bool {
+ d.fs.tasksMu.RLock()
+ defer d.fs.tasksMu.RUnlock()
+ return d.hasChildrenLocked()
+}
+
+// RmDir implements kernfs.Inode.RmDir.
+func (d *dir) RmDir(ctx context.Context, name string, child kernfs.Inode) error {
+ // Unlike a normal directory, we need to recheck that the cgroup is empty,
+ // since vfs/kernfs can't stop tasks from entering or leaving the cgroup.
+ d.fs.tasksMu.RLock()
+ defer d.fs.tasksMu.RUnlock()
+
+ cgi, ok := child.(*cgroupInode)
+ if !ok {
+ return linuxerr.ENOTDIR
+ }
+ if cgi.dir.hasChildrenLocked() {
+ return linuxerr.ENOTEMPTY
+ }
+
+ // Disallow deletion of the effective root cgroup.
+ if cgi == d.fs.effectiveRoot.Inode().(*cgroupInode) {
+ ctx.Warningf("Cannot delete initial cgroup for new tasks %q", d.fs.effectiveRoot.FSLocalPath())
+ return linuxerr.EBUSY
+ }
+
+ err := d.OrderedChildren.RmDir(ctx, name, child)
+ if err == nil {
+ d.InodeAttrs.DecLinks()
+ }
+ return err
}
// controllerFile represents a generic control file that appears within a cgroup
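
The hasChildrenLocked test above is plain Unix link-count arithmetic: a directory holds one link from its parent and one from its own ".", and each subdirectory adds one more through its ".." entry. Stated as a predicate (illustrative helper, not part of the diff):

    // emptyCgroupDir mirrors hasChildrenLocked: a link count beyond the self
    // link and "." implies sub-cgroups; resident member tasks also block
    // deletion.
    func emptyCgroupDir(links uint32, memberTasks int) bool {
    	return links == 2 && memberTasks == 0
    }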
diff --git a/pkg/sentry/fsimpl/cgroupfs/memory.go b/pkg/sentry/fsimpl/cgroupfs/memory.go
index 485c98376..d880c9bc4 100644
--- a/pkg/sentry/fsimpl/cgroupfs/memory.go
+++ b/pkg/sentry/fsimpl/cgroupfs/memory.go
@@ -31,22 +31,34 @@ import (
type memoryController struct {
controllerCommon
- limitBytes int64
+ limitBytes int64
+ softLimitBytes int64
+ moveChargeAtImmigrate int64
}
var _ controller = (*memoryController)(nil)
func newMemoryController(fs *filesystem, defaults map[string]int64) *memoryController {
c := &memoryController{
- // Linux sets this to (PAGE_COUNTER_MAX * PAGE_SIZE) by default, which
- // is ~ 2**63 on a 64-bit system. So essentially, inifinity. The exact
- // value isn't very important.
- limitBytes: math.MaxInt64,
+ // Linux sets these limits to (PAGE_COUNTER_MAX * PAGE_SIZE) by default,
+ // which is ~ 2**63 on a 64-bit system. So essentially, infinity. The
+ // exact value isn't very important.
+
+ limitBytes: math.MaxInt64,
+ softLimitBytes: math.MaxInt64,
}
- if val, ok := defaults["memory.limit_in_bytes"]; ok {
- c.limitBytes = val
- delete(defaults, "memory.limit_in_bytes")
+
+ consumeDefault := func(name string, valPtr *int64) {
+ if val, ok := defaults[name]; ok {
+ *valPtr = val
+ delete(defaults, name)
+ }
}
+
+ consumeDefault("memory.limit_in_bytes", &c.limitBytes)
+ consumeDefault("memory.soft_limit_in_bytes", &c.softLimitBytes)
+ consumeDefault("memory.move_charge_at_immigrate", &c.moveChargeAtImmigrate)
+
c.controllerCommon.init(controllerMemory, fs)
return c
}
@@ -55,6 +67,8 @@ func newMemoryController(fs *filesystem, defaults map[string]int64) *memoryContr
func (c *memoryController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
contents["memory.usage_in_bytes"] = c.fs.newControllerFile(ctx, creds, &memoryUsageInBytesData{})
contents["memory.limit_in_bytes"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.limitBytes))
+ contents["memory.soft_limit_in_bytes"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.softLimitBytes))
+ contents["memory.move_charge_at_immigrate"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.moveChargeAtImmigrate))
}
// +stateify savable
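
The consumeDefault closure above is a read-and-delete pattern: each recognized key is removed from the defaults map, so whatever remains afterwards can be flagged as an unknown default. The same pattern as a free function (illustrative sketch):

    // consume copies defaults[name] into *dst if present and deletes the key,
    // leaving only unrecognized entries behind.
    func consume(defaults map[string]int64, name string, dst *int64) {
    	if v, ok := defaults[name]; ok {
    		*dst = v
    		delete(defaults, name)
    	}
    }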
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
index f981ff296..e0b879339 100644
--- a/pkg/sentry/fsimpl/devpts/BUILD
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -45,7 +45,6 @@ go_library(
"//pkg/sentry/unimpl",
"//pkg/sentry/vfs",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index 7a488e9fd..e711debcb 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Name is the filesystem name.
@@ -180,7 +179,7 @@ func (i *rootInode) allocateTerminal(ctx context.Context, creds *auth.Credential
i.mu.Lock()
defer i.mu.Unlock()
if i.nextIdx == math.MaxUint32 {
- return nil, syserror.ENOMEM
+ return nil, linuxerr.ENOMEM
}
idx := i.nextIdx
i.nextIdx++
@@ -241,7 +240,7 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, erro
// Not a static entry.
idx, err := strconv.ParseUint(name, 10, 32)
if err != nil {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
i.mu.Lock()
defer i.mu.Unlock()
@@ -250,7 +249,7 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, erro
return ri, nil
}
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
// IterDirents implements kernfs.Inode.IterDirents.
diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go
index 9cb21e83b..609623f9f 100644
--- a/pkg/sentry/fsimpl/devpts/line_discipline.go
+++ b/pkg/sentry/fsimpl/devpts/line_discipline.go
@@ -20,10 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -203,7 +203,7 @@ func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSeque
} else if notifyEcho {
l.masterWaiter.Notify(waiter.ReadableEvents)
}
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
@@ -220,7 +220,7 @@ func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequ
l.replicaWaiter.Notify(waiter.ReadableEvents)
return n, nil
}
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
func (l *lineDiscipline) outputQueueReadSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error {
@@ -242,7 +242,7 @@ func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequ
}
return n, nil
}
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
func (l *lineDiscipline) outputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
@@ -257,7 +257,7 @@ func (l *lineDiscipline) outputQueueWrite(ctx context.Context, src usermem.IOSeq
l.masterWaiter.Notify(waiter.ReadableEvents)
return n, nil
}
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
// transformer is a helper interface to make it easier to stateify queue.
diff --git a/pkg/sentry/fsimpl/devpts/queue.go b/pkg/sentry/fsimpl/devpts/queue.go
index ff1d89955..85aeefa43 100644
--- a/pkg/sentry/fsimpl/devpts/queue.go
+++ b/pkg/sentry/fsimpl/devpts/queue.go
@@ -17,12 +17,12 @@ package devpts
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -110,7 +110,7 @@ func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipl
defer q.mu.Unlock()
if !q.readable {
- return 0, false, false, syserror.ErrWouldBlock
+ return 0, false, false, linuxerr.ErrWouldBlock
}
if dst.NumBytes() > canonMaxBytes {
@@ -156,7 +156,7 @@ func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscip
room := waitBufMaxBytes - q.waitBufLen
// If out of room, return EAGAIN.
if room == 0 && copyLen > 0 {
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
// Cap the size of the wait buffer.
if copyLen > room {
diff --git a/pkg/sentry/fsimpl/eventfd/BUILD b/pkg/sentry/fsimpl/eventfd/BUILD
index c09fdc7f9..1cb049a29 100644
--- a/pkg/sentry/fsimpl/eventfd/BUILD
+++ b/pkg/sentry/fsimpl/eventfd/BUILD
@@ -9,11 +9,11 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fdnotifier",
"//pkg/hostarch",
"//pkg/log",
"//pkg/sentry/vfs",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go
index 4f79cfcb7..af5ba5131 100644
--- a/pkg/sentry/fsimpl/eventfd/eventfd.go
+++ b/pkg/sentry/fsimpl/eventfd/eventfd.go
@@ -22,11 +22,11 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -149,7 +149,7 @@ func (efd *EventFileDescription) hostReadLocked(ctx context.Context, dst usermem
var buf [8]byte
if _, err := unix.Read(efd.hostfd, buf[:]); err != nil {
if err == unix.EWOULDBLOCK {
- return syserror.ErrWouldBlock
+ return linuxerr.ErrWouldBlock
}
return err
}
@@ -167,7 +167,7 @@ func (efd *EventFileDescription) read(ctx context.Context, dst usermem.IOSequenc
// We can't complete the read if the value is currently zero.
if efd.val == 0 {
efd.mu.Unlock()
- return syserror.ErrWouldBlock
+ return linuxerr.ErrWouldBlock
}
// Update the value based on the mode the event is operating in.
@@ -200,7 +200,7 @@ func (efd *EventFileDescription) hostWriteLocked(val uint64) error {
hostarch.ByteOrder.PutUint64(buf[:], val)
_, err := unix.Write(efd.hostfd, buf[:])
if err == unix.EWOULDBLOCK {
- return syserror.ErrWouldBlock
+ return linuxerr.ErrWouldBlock
}
return err
}
@@ -232,7 +232,7 @@ func (efd *EventFileDescription) Signal(val uint64) error {
// uint64 minus 1.
if val > math.MaxUint64-1-efd.val {
efd.mu.Unlock()
- return syserror.ErrWouldBlock
+ return linuxerr.ErrWouldBlock
}
efd.val += val
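
The Signal check above encodes the eventfd contract: the counter may hold at most math.MaxUint64 - 1, so an addition that would exceed that must fail with ErrWouldBlock rather than wrap. As a standalone predicate (sketch):

    import "math"

    // canAdd reports whether val fits into an eventfd counter currently at cur
    // without exceeding the maximum representable value of MaxUint64 - 1.
    func canAdd(cur, val uint64) bool {
    	return val <= math.MaxUint64-1-cur
    }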
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/BUILD
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
index 871df5984..05c4fbeb2 100644
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -59,7 +59,6 @@ go_library(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
@@ -84,7 +83,6 @@ go_test(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
index dab1e779d..0f855ac59 100644
--- a/pkg/sentry/fsimpl/fuse/dev.go
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -38,7 +37,7 @@ type fuseDevice struct{}
// Open implements vfs.Device.Open.
func (fuseDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
if !kernel.FUSEEnabled {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
var fd DeviceFD
@@ -126,7 +125,7 @@ func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset in
return 0, linuxerr.EPERM
}
- return 0, syserror.ENOSYS
+ return 0, linuxerr.ENOSYS
}
// Read implements vfs.FileDescriptionImpl.Read.
@@ -192,7 +191,7 @@ func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts
}
if req == nil {
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
// We already checked the size: dst must be able to fit the whole request.
@@ -205,7 +204,7 @@ func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts
return 0, err
}
if n != len(req.data) {
- return 0, syserror.EIO
+ return 0, linuxerr.EIO
}
if req.hdr.Opcode == linux.FUSE_WRITE {
@@ -214,7 +213,7 @@ func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts
return 0, err
}
if written != len(req.payload) {
- return 0, syserror.EIO
+ return 0, linuxerr.EIO
}
n += int(written)
}
@@ -238,7 +237,7 @@ func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset i
return 0, linuxerr.EPERM
}
- return 0, syserror.ENOSYS
+ return 0, linuxerr.ENOSYS
}
// Write implements vfs.FileDescriptionImpl.Write.
@@ -395,7 +394,7 @@ func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64
return 0, linuxerr.EPERM
}
- return 0, syserror.ENOSYS
+ return 0, linuxerr.ENOSYS
}
// sendResponse sends a response to the waiting task (if any).
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
index 04250d796..8951b5ba8 100644
--- a/pkg/sentry/fsimpl/fuse/dev_test.go
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -20,11 +20,11 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -186,7 +186,7 @@ func ReadTest(serverTask *kernel.Task, fd *vfs.FileDescription, inIOseq usermem.
// "would block".
n, err = dev.Read(serverTask, inIOseq, vfs.ReadOptions{})
total += n
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
break
}
diff --git a/pkg/sentry/fsimpl/fuse/directory.go b/pkg/sentry/fsimpl/fuse/directory.go
index fcc5d9a2a..9611edd5a 100644
--- a/pkg/sentry/fsimpl/fuse/directory.go
+++ b/pkg/sentry/fsimpl/fuse/directory.go
@@ -19,10 +19,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -32,27 +32,27 @@ type directoryFD struct {
// Allocate implements directoryFD.Allocate.
func (*directoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
// PRead implements vfs.FileDescriptionImpl.PRead.
func (*directoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
// Read implements vfs.FileDescriptionImpl.Read.
func (*directoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (*directoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
// Write implements vfs.FileDescriptionImpl.Write.
func (*directoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index 172cbd88f..af16098d2 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -30,7 +30,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -612,7 +611,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo
return nil, err
}
if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) {
- return nil, syserror.EIO
+ return nil, linuxerr.EIO
}
child := i.fs.newInode(ctx, out.NodeID, out.Attr)
return child, nil
diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go
index 35d0ab6f4..fe119aa43 100644
--- a/pkg/sentry/fsimpl/fuse/read_write.go
+++ b/pkg/sentry/fsimpl/fuse/read_write.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/syserror"
)
// ReadInPages sends FUSE_READ requests for the size after rounding it up to
@@ -221,7 +220,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
// Write more than requested? EIO.
if out.Size > toWrite {
- return 0, syserror.EIO
+ return 0, linuxerr.EIO
}
written += out.Size
diff --git a/pkg/sentry/fsimpl/fuse/regular_file.go b/pkg/sentry/fsimpl/fuse/regular_file.go
index 6c4de3507..38cde8208 100644
--- a/pkg/sentry/fsimpl/fuse/regular_file.go
+++ b/pkg/sentry/fsimpl/fuse/regular_file.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -108,7 +107,7 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
return 0, err
}
if int64(cp) != toCopy {
- return 0, syserror.EIO
+ return 0, linuxerr.EIO
}
copied += toCopy
}
@@ -205,7 +204,7 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
return 0, offset, err
}
if int64(cp) != srclen {
- return 0, offset, syserror.EIO
+ return 0, offset, linuxerr.EIO
}
n, err := fd.inode().fs.Write(ctx, fd, uint64(offset), uint32(srclen), data)
@@ -216,7 +215,7 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
if n == 0 {
// We have checked srclen != 0 previously.
// If err == nil, then it's a short write and we return EIO.
- return 0, offset, syserror.EIO
+ return 0, offset, linuxerr.EIO
}
written = int64(n)
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 752060044..509dd0e1a 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -54,7 +54,10 @@ go_library(
"//pkg/fdnotifier",
"//pkg/fspath",
"//pkg/hostarch",
+ "//pkg/lisafs",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/p9",
"//pkg/refs",
@@ -79,7 +82,6 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/unet",
"//pkg/usermem",
"//pkg/waiter",
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 5c48a9fee..d99a6112c 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -222,47 +222,88 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
off := uint64(0)
const count = 64 * 1024 // for consistency with the vfs1 client
d.handleMu.RLock()
- if d.readFile.isNil() {
+ if !d.isReadFileOk() {
// This should not be possible because a readable handle should
// have been opened when the calling directoryFD was opened.
d.handleMu.RUnlock()
panic("gofer.dentry.getDirents called without a readable handle")
}
+ // shouldSeek0 indicates whether the server should SEEK to 0 before reading
+ // directory entries.
+ shouldSeek0 := true
for {
- p9ds, err := d.readFile.readdir(ctx, off, count)
- if err != nil {
- d.handleMu.RUnlock()
- return nil, err
- }
- if len(p9ds) == 0 {
- d.handleMu.RUnlock()
- break
- }
- for _, p9d := range p9ds {
- if p9d.Name == "." || p9d.Name == ".." {
- continue
+ if d.fs.opts.lisaEnabled {
+ countLisa := int32(count)
+ if shouldSeek0 {
+ // See lisafs.Getdents64Req.Count.
+ countLisa = -countLisa
+ shouldSeek0 = false
+ }
+ lisafsDs, err := d.readFDLisa.Getdents64(ctx, countLisa)
+ if err != nil {
+ d.handleMu.RUnlock()
+ return nil, err
+ }
+ if len(lisafsDs) == 0 {
+ d.handleMu.RUnlock()
+ break
+ }
+ for i := range lisafsDs {
+ name := string(lisafsDs[i].Name)
+ if name == "." || name == ".." {
+ continue
+ }
+ dirent := vfs.Dirent{
+ Name: name,
+ Ino: d.fs.inoFromKey(inoKey{
+ ino: uint64(lisafsDs[i].Ino),
+ devMinor: uint32(lisafsDs[i].DevMinor),
+ devMajor: uint32(lisafsDs[i].DevMajor),
+ }),
+ NextOff: int64(len(dirents) + 1),
+ Type: uint8(lisafsDs[i].Type),
+ }
+ dirents = append(dirents, dirent)
+ if realChildren != nil {
+ realChildren[name] = struct{}{}
+ }
}
- dirent := vfs.Dirent{
- Name: p9d.Name,
- Ino: d.fs.inoFromQIDPath(p9d.QID.Path),
- NextOff: int64(len(dirents) + 1),
+ } else {
+ p9ds, err := d.readFile.readdir(ctx, off, count)
+ if err != nil {
+ d.handleMu.RUnlock()
+ return nil, err
}
- // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
- // DMSOCKET.
- switch p9d.Type {
- case p9.TypeSymlink:
- dirent.Type = linux.DT_LNK
- case p9.TypeDir:
- dirent.Type = linux.DT_DIR
- default:
- dirent.Type = linux.DT_REG
+ if len(p9ds) == 0 {
+ d.handleMu.RUnlock()
+ break
}
- dirents = append(dirents, dirent)
- if realChildren != nil {
- realChildren[p9d.Name] = struct{}{}
+ for _, p9d := range p9ds {
+ if p9d.Name == "." || p9d.Name == ".." {
+ continue
+ }
+ dirent := vfs.Dirent{
+ Name: p9d.Name,
+ Ino: d.fs.inoFromQIDPath(p9d.QID.Path),
+ NextOff: int64(len(dirents) + 1),
+ }
+ // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
+ // DMSOCKET.
+ switch p9d.Type {
+ case p9.TypeSymlink:
+ dirent.Type = linux.DT_LNK
+ case p9.TypeDir:
+ dirent.Type = linux.DT_DIR
+ default:
+ dirent.Type = linux.DT_REG
+ }
+ dirents = append(dirents, dirent)
+ if realChildren != nil {
+ realChildren[p9d.Name] = struct{}{}
+ }
}
+ off = p9ds[len(p9ds)-1].Offset
}
- off = p9ds[len(p9ds)-1].Offset
}
}
// Emit entries for synthetic children.
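
The lisafs branch above uses a sign trick: per the comment referencing lisafs.Getdents64Req.Count, a negative count on the first call asks the server to SEEK to 0 before reading, after which positive counts continue from the server-side offset. A hedged sketch of a full directory read built on that convention (the Getdents64 signature is taken from its use above):

    import (
    	"gvisor.dev/gvisor/pkg/context"
    	"gvisor.dev/gvisor/pkg/lisafs"
    )

    // getdentser abstracts the read FD; the method shape matches the call above.
    type getdentser interface {
    	Getdents64(ctx context.Context, count int32) ([]lisafs.Dirent64, error)
    }

    // readAllDirents drains a directory, negating the count once so the server
    // rewinds to offset 0 before the first batch.
    func readAllDirents(ctx context.Context, fd getdentser) ([]lisafs.Dirent64, error) {
    	const count = int32(64 * 1024) // Same batch size as the vfs1 client.
    	first := true
    	var all []lisafs.Dirent64
    	for {
    		c := count
    		if first {
    			c = -c // Negative count => SEEK to 0 first.
    			first = false
    		}
    		ds, err := fd.Getdents64(ctx, c)
    		if err != nil {
    			return nil, err
    		}
    		if len(ds) == 0 {
    			return all, nil
    		}
    		all = append(all, ds...)
    	}
    }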
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 05b776c2e..f7b3446d3 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -21,10 +21,12 @@ import (
"sync"
"sync/atomic"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
"gvisor.dev/gvisor/pkg/sentry/fsmetric"
@@ -33,7 +35,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Sync implements vfs.FilesystemImpl.Sync.
@@ -54,9 +55,47 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// regardless.
var retErr error
+ if fs.opts.lisaEnabled {
+ // Try accumulating all FDIDs to fsync and fsync them via one RPC, as
+ // opposed to making an RPC per FDID. Passing a non-nil accFsyncFDIDs to
+ // dentry.syncCachedFile() and specialFileFD.sync() causes them to skip the
+ // RPC and instead accumulate syncable FDIDs in the passed slice.
+ accFsyncFDIDs := make([]lisafs.FDID, 0, len(ds)+len(sffds))
+
+ // Sync syncable dentries.
+ for _, d := range ds {
+ if err := d.syncCachedFile(ctx, true /* forFilesystemSync */, &accFsyncFDIDs); err != nil {
+ ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
+ }
+ }
+
+ // Sync special files, which may be writable but do not use dentry shared
+ // handles (so they won't be synced by the above).
+ for _, sffd := range sffds {
+ if err := sffd.sync(ctx, true /* forFilesystemSync */, &accFsyncFDIDs); err != nil {
+ ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
+ }
+ }
+
+ if err := fs.clientLisa.SyncFDs(ctx, accFsyncFDIDs); err != nil {
+ ctx.Infof("gofer.filesystem.Sync: fs.fsyncMultipleFDLisa failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
+ }
+
+ return retErr
+ }
+
// Sync syncable dentries.
for _, d := range ds {
- if err := d.syncCachedFile(ctx, true /* forFilesystemSync */); err != nil {
+ if err := d.syncCachedFile(ctx, true /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */); err != nil {
ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err)
if retErr == nil {
retErr = err
@@ -67,7 +106,7 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// Sync special files, which may be writable but do not use dentry shared
// handles (so they won't be synced by the above).
for _, sffd := range sffds {
- if err := sffd.sync(ctx, true /* forFilesystemSync */); err != nil {
+ if err := sffd.sync(ctx, true /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */); err != nil {
ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err)
if retErr == nil {
retErr = err
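
The hunk above turns N fsync RPCs into one: with lisafs enabled, each per-file sync only appends its FDID to accFsyncFDIDs, and a single SyncFDs call flushes the batch while the first error is remembered. The accumulate-then-flush shape in isolation (sketch; callback signatures assumed):

    import "gvisor.dev/gvisor/pkg/lisafs"

    // syncAll runs each collector, which appends syncable FDIDs instead of
    // issuing its own RPC, then flushes the whole batch once. The first error
    // is kept but doesn't stop the walk, matching the Sync loop above.
    func syncAll(collectors []func(*[]lisafs.FDID) error, flush func([]lisafs.FDID) error) error {
    	var ids []lisafs.FDID
    	var retErr error
    	for _, collect := range collectors {
    		if err := collect(&ids); err != nil && retErr == nil {
    			retErr = err
    		}
    	}
    	if err := flush(ids); err != nil && retErr == nil {
    		retErr = err
    	}
    	return retErr
    }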
@@ -198,7 +237,13 @@ afterSymlink:
rp.Advance()
return d.parent, followedSymlink, nil
}
- child, err := fs.getChildLocked(ctx, d, name, ds)
+ var child *dentry
+ var err error
+ if fs.opts.lisaEnabled {
+ child, err = fs.getChildAndWalkPathLocked(ctx, d, rp, ds)
+ } else {
+ child, err = fs.getChildLocked(ctx, d, name, ds)
+ }
if err != nil {
return nil, false, err
}
@@ -220,6 +265,99 @@ afterSymlink:
return child, followedSymlink, nil
}
+// Preconditions:
+// * fs.opts.lisaEnabled.
+// * fs.renameMu must be locked.
+// * parent.dirMu must be locked.
+// * parent.isDir().
+// * parent and the dentry at name have been revalidated.
+func (fs *filesystem) getChildAndWalkPathLocked(ctx context.Context, parent *dentry, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
+ // Note that pit is a copy of the iterator that does not affect rp.
+ pit := rp.Pit()
+ first := pit.String()
+ if len(first) > maxFilenameLen {
+ return nil, linuxerr.ENAMETOOLONG
+ }
+ if child, ok := parent.children[first]; ok || parent.isSynthetic() {
+ if child == nil {
+ return nil, linuxerr.ENOENT
+ }
+ return child, nil
+ }
+
+ // Walk as much of the path as possible in 1 RPC.
+ names := []string{first}
+ for pit = pit.Next(); pit.Ok(); pit = pit.Next() {
+ name := pit.String()
+ if name == "." {
+ continue
+ }
+ if name == ".." {
+ break
+ }
+ names = append(names, name)
+ }
+ status, inodes, err := parent.controlFDLisa.WalkMultiple(ctx, names)
+ if err != nil {
+ return nil, err
+ }
+ if len(inodes) == 0 {
+ parent.cacheNegativeLookupLocked(first)
+ return nil, linuxerr.ENOENT
+ }
+
+ // Add the walked inodes into the dentry tree.
+ curParent := parent
+ curParentDirMuLock := func() {
+ if curParent != parent {
+ curParent.dirMu.Lock()
+ }
+ }
+ curParentDirMuUnlock := func() {
+ if curParent != parent {
+ curParent.dirMu.Unlock() // +checklocksforce: locked via curParentDirMuLock().
+ }
+ }
+ var ret *dentry
+ var dentryCreationErr error
+ for i := range inodes {
+ if dentryCreationErr != nil {
+ fs.clientLisa.CloseFDBatched(ctx, inodes[i].ControlFD)
+ continue
+ }
+
+ child, err := fs.newDentryLisa(ctx, &inodes[i])
+ if err != nil {
+ fs.clientLisa.CloseFDBatched(ctx, inodes[i].ControlFD)
+ dentryCreationErr = err
+ continue
+ }
+ curParentDirMuLock()
+ curParent.cacheNewChildLocked(child, names[i])
+ curParentDirMuUnlock()
+ // For now, child has 0 references, so our caller should call
+ // child.checkCachingLocked(). curParent gained a ref so we should also
+ // call curParent.checkCachingLocked() so it can be removed from the cache
+ // if needed. We only do that for the first iteration because all
+ // subsequent parents would have already been added to ds.
+ if i == 0 {
+ *ds = appendDentry(*ds, curParent)
+ }
+ *ds = appendDentry(*ds, child)
+ curParent = child
+ if i == 0 {
+ ret = child
+ }
+ }
+
+ if status == lisafs.WalkComponentDoesNotExist && curParent.isDir() {
+ curParentDirMuLock()
+ curParent.cacheNegativeLookupLocked(names[len(inodes)])
+ curParentDirMuUnlock()
+ }
+ return ret, dentryCreationErr
+}
+
// getChildLocked returns a dentry representing the child of parent with the
// given name. Returns ENOENT if the child doesn't exist.
//
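
getChildAndWalkPathLocked batches as many forward components as possible into a single WalkMultiple RPC: "." is skipped and ".." ends the batch, since it may leave the subtree being walked. The batching rule in isolation (sketch over a plain string path rather than rp.Pit()):

    import "strings"

    // batchComponents collects the components eligible for one WalkMultiple
    // RPC, mirroring the loop above: skip ".", stop at "..".
    func batchComponents(path string) []string {
    	var names []string
    	for _, c := range strings.Split(path, "/") {
    		switch c {
    		case "", ".":
    			continue
    		case "..":
    			return names
    		}
    		names = append(names, c)
    	}
    	return names
    }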
@@ -228,32 +366,47 @@ afterSymlink:
// * parent.dirMu must be locked.
// * parent.isDir().
// * name is not "." or "..".
-// * dentry at name has been revalidated
+// * parent and the dentry at name have been revalidated.
func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if len(name) > maxFilenameLen {
return nil, linuxerr.ENAMETOOLONG
}
if child, ok := parent.children[name]; ok || parent.isSynthetic() {
if child == nil {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
return child, nil
}
- qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
- if err != nil {
- if linuxerr.Equals(linuxerr.ENOENT, err) {
- parent.cacheNegativeLookupLocked(name)
+ var child *dentry
+ if fs.opts.lisaEnabled {
+ childInode, err := parent.controlFDLisa.Walk(ctx, name)
+ if err != nil {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
+ parent.cacheNegativeLookupLocked(name)
+ }
+ return nil, err
+ }
+ // Create a new dentry representing the file.
+ child, err = fs.newDentryLisa(ctx, childInode)
+ if err != nil {
+ fs.clientLisa.CloseFDBatched(ctx, childInode.ControlFD)
+ return nil, err
+ }
+ } else {
+ qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
+ if err != nil {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
+ parent.cacheNegativeLookupLocked(name)
+ }
+ return nil, err
+ }
+ // Create a new dentry representing the file.
+ child, err = fs.newDentry(ctx, file, qid, attrMask, &attr)
+ if err != nil {
+ file.close(ctx)
+ return nil, err
}
- return nil, err
- }
-
- // Create a new dentry representing the file.
- child, err := fs.newDentry(ctx, file, qid, attrMask, &attr)
- if err != nil {
- file.close(ctx)
- delete(parent.children, name)
- return nil, err
}
parent.cacheNewChildLocked(child, name)
appendNewChildDentry(ds, parent, child)
@@ -329,7 +482,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath,
// Preconditions:
// * !rp.Done().
// * For the final path component in rp, !rp.ShouldFollowSymlink().
-func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string, ds **[]*dentry) error, createInSyntheticDir func(parent *dentry, name string) error) error {
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error), createInSyntheticDir func(parent *dentry, name string) error, updateChild func(child *dentry)) error {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
@@ -349,7 +502,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
return linuxerr.EEXIST
}
if parent.isDeleted() {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, name, &ds); err != nil {
return err
@@ -395,7 +548,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
return err
}
if !dir && rp.MustBeDir() {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if parent.isSynthetic() {
if createInSyntheticDir == nil {
@@ -416,9 +569,26 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
// No cached dentry exists; however, in InteropModeShared there might still be
// an existing file at name. Just attempt the file creation RPC anyway. If a
// file does exist, the RPC will fail with EEXIST like we would have.
- if err := createInRemoteDir(parent, name, &ds); err != nil {
+ lisaInode, err := createInRemoteDir(parent, name, &ds)
+ if err != nil {
return err
}
+ // lisafs may aggressively cache newly created inodes. This has helped reduce
+ // Walk RPCs in practice.
+ if lisaInode != nil {
+ child, err := fs.newDentryLisa(ctx, lisaInode)
+ if err != nil {
+ fs.clientLisa.CloseFDBatched(ctx, lisaInode.ControlFD)
+ return err
+ }
+ parent.cacheNewChildLocked(child, name)
+ appendNewChildDentry(&ds, parent, child)
+
+ // lisafs may update dentry properties upon successful creation.
+ if updateChild != nil {
+ updateChild(child)
+ }
+ }
if fs.opts.interop != InteropModeShared {
if child, ok := parent.children[name]; ok && child == nil {
// Delete the now-stale negative dentry.
@@ -463,7 +633,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
}
} else {
if name == "." || name == ".." {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
}
@@ -486,7 +656,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
child, ok = parent.children[name]
if ok && child == nil {
// Hit a negative cached entry, child doesn't exist.
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
} else {
child, _, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
@@ -552,7 +722,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
// child must be a non-directory file.
if child != nil && child.isDir() {
vfsObj.AbortDeleteDentry(&child.vfsd) // +checklocksforce: see above.
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
if rp.MustBeDir() {
if child != nil {
@@ -563,10 +733,14 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
}
if parent.isSynthetic() {
if child == nil {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
} else if child == nil || !child.isSynthetic() {
- err = parent.file.unlinkAt(ctx, name, flags)
+ if fs.opts.lisaEnabled {
+ err = parent.controlFDLisa.UnlinkAt(ctx, name, flags)
+ } else {
+ err = parent.file.unlinkAt(ctx, name, flags)
+ }
if err != nil {
if child != nil {
vfsObj.AbortDeleteDentry(&child.vfsd) // +checklocksforce: see above.
@@ -659,40 +833,43 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa
// LinkAt implements vfs.FilesystemImpl.LinkAt.
func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, _ **[]*dentry) error {
+ err := fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, ds **[]*dentry) (*lisafs.Inode, error) {
if rp.Mount() != vd.Mount() {
- return linuxerr.EXDEV
+ return nil, linuxerr.EXDEV
}
d := vd.Dentry().Impl().(*dentry)
if d.isDir() {
- return linuxerr.EPERM
+ return nil, linuxerr.EPERM
}
gid := auth.KGID(atomic.LoadUint32(&d.gid))
uid := auth.KUID(atomic.LoadUint32(&d.uid))
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
if err := vfs.MayLink(rp.Credentials(), mode, uid, gid); err != nil {
- return err
+ return nil, err
}
if d.nlink == 0 {
- return syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
if d.nlink == math.MaxUint32 {
- return linuxerr.EMLINK
+ return nil, linuxerr.EMLINK
}
- if err := parent.file.link(ctx, d.file, childName); err != nil {
- return err
+ if fs.opts.lisaEnabled {
+ return parent.controlFDLisa.LinkAt(ctx, d.controlFDLisa.ID(), childName)
}
+ return nil, parent.file.link(ctx, d.file, childName)
+ }, nil, nil)
+ if err == nil {
// Success!
- atomic.AddUint32(&d.nlink, 1)
- return nil
- }, nil)
+ vd.Dentry().Impl().(*dentry).incLinks()
+ }
+ return err
}
// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
creds := rp.Credentials()
- return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) error {
+ return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) {
// If the parent is a setgid directory, use the parent's GID
// rather than the caller's and enable setgid.
kgid := creds.EffectiveKGID
@@ -701,9 +878,18 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
kgid = auth.KGID(atomic.LoadUint32(&parent.gid))
mode |= linux.S_ISGID
}
- if _, err := parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)); err != nil {
+ var (
+ childDirInode *lisafs.Inode
+ err error
+ )
+ if fs.opts.lisaEnabled {
+ childDirInode, err = parent.controlFDLisa.MkdirAt(ctx, name, mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(kgid))
+ } else {
+ _, err = parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid))
+ }
+ if err != nil {
if !opts.ForSyntheticMountpoint || linuxerr.Equals(linuxerr.EEXIST, err) {
- return err
+ return nil, err
}
ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err)
parent.createSyntheticChildLocked(&createSyntheticOpts{
@@ -717,7 +903,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
if fs.opts.interop != InteropModeShared {
parent.incLinks()
}
- return nil
+ return childDirInode, nil
}, func(parent *dentry, name string) error {
if !opts.ForSyntheticMountpoint {
// Can't create non-synthetic files in synthetic directories.
@@ -731,16 +917,26 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
})
parent.incLinks()
return nil
- })
+ }, nil)
}
// MknodAt implements vfs.FilesystemImpl.MknodAt.
func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) {
creds := rp.Credentials()
- _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
- if !linuxerr.Equals(linuxerr.EPERM, err) {
- return err
+ var (
+ childInode *lisafs.Inode
+ err error
+ )
+ if fs.opts.lisaEnabled {
+ childInode, err = parent.controlFDLisa.MknodAt(ctx, name, opts.Mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(creds.EffectiveKGID), opts.DevMinor, opts.DevMajor)
+ } else {
+ _, err = parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ }
+ if err == nil {
+ return childInode, nil
+ } else if !linuxerr.Equals(linuxerr.EPERM, err) {
+ return nil, err
}
// EPERM means that gofer does not allow creating a socket or pipe. Fallback
@@ -751,10 +947,10 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
switch {
case err == nil:
// Step succeeded, another file exists.
- return linuxerr.EEXIST
+ return nil, linuxerr.EEXIST
case !linuxerr.Equals(linuxerr.ENOENT, err):
// Unexpected error.
- return err
+ return nil, err
}
switch opts.Mode.FileType() {
@@ -767,7 +963,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
endpoint: opts.Endpoint,
})
*ds = appendDentry(*ds, parent)
- return nil
+ return nil, nil
case linux.S_IFIFO:
parent.createSyntheticChildLocked(&createSyntheticOpts{
name: name,
@@ -777,11 +973,11 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize),
})
*ds = appendDentry(*ds, parent)
- return nil
+ return nil, nil
}
// Retain error from gofer if synthetic file cannot be created internally.
- return linuxerr.EPERM
- }, nil)
+ return nil, linuxerr.EPERM
+ }, nil, nil)
}
// OpenAt implements vfs.FilesystemImpl.OpenAt.
@@ -811,7 +1007,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if rp.Done() {
// Reject attempts to open mount root directory with O_CREAT.
if mayCreate && rp.MustBeDir() {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
if mustCreate {
return nil, linuxerr.EEXIST
@@ -841,7 +1037,7 @@ afterTrailingSymlink:
}
// Reject attempts to open directories with O_CREAT.
if mayCreate && rp.MustBeDir() {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, rp.Component(), &ds); err != nil {
return nil, err
@@ -922,11 +1118,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
case linux.S_IFDIR:
// Can't open directories with O_CREAT.
if opts.Flags&linux.O_CREAT != 0 {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
// Can't open directories writably.
if ats&vfs.MayWrite != 0 {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
if opts.Flags&linux.O_DIRECT != 0 {
return nil, linuxerr.EINVAL
@@ -987,6 +1183,23 @@ func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptio
if opts.Flags&linux.O_DIRECT != 0 {
return nil, linuxerr.EINVAL
}
+ if d.fs.opts.lisaEnabled {
+ // Note that the special value linux.SockType = 0 is interpreted by lisafs
+ // as "do not care about the socket type". Analogous to p9.AnonymousSocket.
+ sockFD, err := d.controlFDLisa.Connect(ctx, 0 /* sockType */)
+ if err != nil {
+ return nil, err
+ }
+ fd, err := host.NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), sockFD, &host.NewFDOptions{
+ HaveFlags: true,
+ Flags: opts.Flags,
+ })
+ if err != nil {
+ unix.Close(sockFD)
+ return nil, err
+ }
+ return fd, nil
+ }
fdObj, err := d.file.connect(ctx, p9.AnonymousSocket)
if err != nil {
return nil, err
@@ -999,6 +1212,7 @@ func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptio
fdObj.Close()
return nil, err
}
+ // Ownership has been transferred to fd.
fdObj.Release()
return fd, nil
}
@@ -1018,7 +1232,13 @@ func (d *dentry) openSpecialFile(ctx context.Context, mnt *vfs.Mount, opts *vfs.
// since closed its end.
isBlockingOpenOfNamedPipe := d.fileType() == linux.S_IFIFO && opts.Flags&linux.O_NONBLOCK == 0
retry:
- h, err := openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0)
+ var h handle
+ var err error
+ if d.fs.opts.lisaEnabled {
+ h, err = openHandleLisa(ctx, d.controlFDLisa, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0)
+ } else {
+ h, err = openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0)
+ }
if err != nil {
if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && linuxerr.Equals(linuxerr.ENXIO, err) {
// An attempt to open a named pipe with O_WRONLY|O_NONBLOCK fails
@@ -1054,7 +1274,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
return nil, err
}
if d.isDeleted() {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
mnt := rp.Mount()
if err := mnt.CheckBeginWrite(); err != nil {
@@ -1062,18 +1282,8 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
}
defer mnt.EndWrite()
- // 9P2000.L's lcreate takes a fid representing the parent directory, and
- // converts it into an open fid representing the created file, so we need
- // to duplicate the directory fid first.
- _, dirfile, err := d.file.walk(ctx, nil)
- if err != nil {
- return nil, err
- }
creds := rp.Credentials()
name := rp.Component()
- // We only want the access mode for creating the file.
- createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask
-
// If the parent is a setgid directory, use the parent's GID rather
// than the caller's.
kgid := creds.EffectiveKGID
@@ -1081,51 +1291,87 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
kgid = auth.KGID(atomic.LoadUint32(&d.gid))
}
- fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid))
- if err != nil {
- dirfile.close(ctx)
- return nil, err
- }
- // Then we need to walk to the file we just created to get a non-open fid
- // representing it, and to get its metadata. This must use d.file since, as
- // explained above, dirfile was invalidated by dirfile.Create().
- _, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name)
- if err != nil {
- openFile.close(ctx)
- if fdobj != nil {
- fdobj.Close()
+ var child *dentry
+ var openP9File p9file
+ openLisaFD := lisafs.InvalidFDID
+ openHostFD := int32(-1)
+ if d.fs.opts.lisaEnabled {
+ ino, openFD, hostFD, err := d.controlFDLisa.OpenCreateAt(ctx, name, opts.Flags&linux.O_ACCMODE, opts.Mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(kgid))
+ if err != nil {
+ return nil, err
+ }
+ openHostFD = int32(hostFD)
+ openLisaFD = openFD
+
+ child, err = d.fs.newDentryLisa(ctx, &ino)
+ if err != nil {
+ d.fs.clientLisa.CloseFDBatched(ctx, ino.ControlFD)
+ d.fs.clientLisa.CloseFDBatched(ctx, openFD)
+ if hostFD >= 0 {
+ unix.Close(hostFD)
+ }
+ return nil, err
+ }
+ } else {
+ // 9P2000.L's lcreate takes a fid representing the parent directory, and
+ // converts it into an open fid representing the created file, so we need
+ // to duplicate the directory fid first.
+ _, dirfile, err := d.file.walk(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+ // We only want the access mode for creating the file.
+ createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask
+
+ fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid))
+ if err != nil {
+ dirfile.close(ctx)
+ return nil, err
+ }
+ // Then we need to walk to the file we just created to get a non-open fid
+ // representing it, and to get its metadata. This must use d.file since, as
+ // explained above, dirfile was invalidated by dirfile.Create().
+ _, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name)
+ if err != nil {
+ openFile.close(ctx)
+ if fdobj != nil {
+ fdobj.Close()
+ }
+ return nil, err
+ }
+
+ // Construct the new dentry.
+ child, err = d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr)
+ if err != nil {
+ nonOpenFile.close(ctx)
+ openFile.close(ctx)
+ if fdobj != nil {
+ fdobj.Close()
+ }
+ return nil, err
}
- return nil, err
- }
- // Construct the new dentry.
- child, err := d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr)
- if err != nil {
- nonOpenFile.close(ctx)
- openFile.close(ctx)
if fdobj != nil {
- fdobj.Close()
+ openHostFD = int32(fdobj.Release())
}
- return nil, err
+ openP9File = openFile
}
// Incorporate the fid that was opened by lcreate.
useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD
if useRegularFileFD {
- openFD := int32(-1)
- if fdobj != nil {
- openFD = int32(fdobj.Release())
- }
child.handleMu.Lock()
if vfs.MayReadFileWithOpenFlags(opts.Flags) {
- child.readFile = openFile
- if fdobj != nil {
- child.readFD = openFD
- child.mmapFD = openFD
+ child.readFile = openP9File
+ child.readFDLisa = d.fs.clientLisa.NewFD(openLisaFD)
+ if openHostFD != -1 {
+ child.readFD = openHostFD
+ child.mmapFD = openHostFD
}
}
if vfs.MayWriteFileWithOpenFlags(opts.Flags) {
- child.writeFile = openFile
- child.writeFD = openFD
+ child.writeFile = openP9File
+ child.writeFDLisa = d.fs.clientLisa.NewFD(openLisaFD)
+ child.writeFD = openHostFD
}
child.handleMu.Unlock()
}
@@ -1147,11 +1393,9 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
childVFSFD = &fd.vfsfd
} else {
h := handle{
- file: openFile,
- fd: -1,
- }
- if fdobj != nil {
- h.fd = int32(fdobj.Release())
+ file: openP9File,
+ fdLisa: d.fs.clientLisa.NewFD(openLisaFD),
+ fd: openHostFD,
}
fd, err := newSpecialFileFD(h, mnt, child, opts.Flags)
if err != nil {
@@ -1268,7 +1512,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
defer newParent.dirMu.Unlock()
}
if newParent.isDeleted() {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
replaced, err := fs.getChildLocked(ctx, newParent, newName, &ds)
if err != nil && !linuxerr.Equals(linuxerr.ENOENT, err) {
@@ -1282,7 +1526,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
replacedVFSD = &replaced.vfsd
if replaced.isDir() {
if !renamed.isDir() {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
if genericIsAncestorDentry(replaced, renamed) {
return linuxerr.ENOTEMPTY
@@ -1305,7 +1549,12 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// Update the remote filesystem.
if !renamed.isSynthetic() {
- if err := renamed.file.rename(ctx, newParent.file, newName); err != nil {
+ if fs.opts.lisaEnabled {
+ err = renamed.controlFDLisa.RenameTo(ctx, newParent.controlFDLisa.ID(), newName)
+ } else {
+ err = renamed.file.rename(ctx, newParent.file, newName)
+ }
+ if err != nil {
vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
return err
}
@@ -1316,7 +1565,12 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if replaced.isDir() {
flags = linux.AT_REMOVEDIR
}
- if err := newParent.file.unlinkAt(ctx, newName, flags); err != nil {
+ if fs.opts.lisaEnabled {
+ err = newParent.controlFDLisa.UnlinkAt(ctx, newName, flags)
+ } else {
+ err = newParent.file.unlinkAt(ctx, newName, flags)
+ }
+ if err != nil {
vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
return err
}
@@ -1432,6 +1686,28 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
for d.isSynthetic() {
d = d.parent
}
+ if fs.opts.lisaEnabled {
+ var statFS lisafs.StatFS
+ if err := d.controlFDLisa.StatFSTo(ctx, &statFS); err != nil {
+ return linux.Statfs{}, err
+ }
+ if statFS.NameLength > maxFilenameLen {
+ statFS.NameLength = maxFilenameLen
+ }
+ return linux.Statfs{
+ // This is primarily for distinguishing a gofer file system in
+ // tests. Instead of defining something completely arbitrary, use a
+ // standard value.
+ Type: linux.V9FS_MAGIC,
+ BlockSize: statFS.BlockSize,
+ Blocks: statFS.Blocks,
+ BlocksFree: statFS.BlocksFree,
+ BlocksAvailable: statFS.BlocksAvailable,
+ Files: statFS.Files,
+ FilesFree: statFS.FilesFree,
+ NameLength: statFS.NameLength,
+ }, nil
+ }
fsstat, err := d.file.statFS(ctx)
if err != nil {
return linux.Statfs{}, err
@@ -1457,11 +1733,21 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, _ **[]*dentry) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) {
creds := rp.Credentials()
+ if fs.opts.lisaEnabled {
+ return parent.controlFDLisa.SymlinkAt(ctx, name, target, lisafs.UID(creds.EffectiveKUID), lisafs.GID(creds.EffectiveKGID))
+ }
_, err := parent.file.symlink(ctx, target, name, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
- return err
- }, nil)
+ return nil, err
+ }, nil, func(child *dentry) {
+ if fs.opts.interop != InteropModeShared {
+ // lisafs caches the symlink target on creation. In practice, this
+ // helps avoid a lot of ReadLink RPCs.
+ child.haveTarget = true
+ child.target = target
+ }
+ })
}
// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
@@ -1506,7 +1792,7 @@ func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
if err != nil {
return nil, err
}
- return d.listXattr(ctx, rp.Credentials(), size)
+ return d.listXattr(ctx, size)
}
// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt.
@@ -1613,6 +1899,9 @@ func (fs *filesystem) MountOptions() string {
if fs.opts.overlayfsStaleRead {
optsKV = append(optsKV, mopt{moptOverlayfsStaleRead, nil})
}
+ if fs.opts.lisaEnabled {
+ optsKV = append(optsKV, mopt{moptLisafs, nil})
+ }
opts := make([]string, 0, len(optsKV))
for _, opt := range optsKV {
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 25d2e39d6..b98825e26 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -48,6 +48,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
@@ -62,7 +63,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -84,6 +84,7 @@ const (
moptForcePageCache = "force_page_cache"
moptLimitHostFDTranslation = "limit_host_fd_translation"
moptOverlayfsStaleRead = "overlayfs_stale_read"
+ moptLisafs = "lisafs"
)
// Valid values for the "cache" mount option.
@@ -119,6 +120,10 @@ type filesystem struct {
// client is the client used by this filesystem. client is immutable.
client *p9.Client `state:"nosave"`
+ // clientLisa is the client used for communicating with the server when
+ // lisafs is enabled. clientLisa is immutable.
+ clientLisa *lisafs.Client `state:"nosave"`
+
// clock is a realtime clock used to set timestamps in file operations.
clock ktime.Clock
@@ -162,6 +167,12 @@ type filesystem struct {
inoMu sync.Mutex `state:"nosave"`
inoByQIDPath map[uint64]uint64 `state:"nosave"`
+ // inoByKey is the same as inoByQIDPath but only used by lisafs. It helps
+ // identify inodes based on the device ID and host inode number provided
+ // by the gofer process. It is not preserved across checkpoint/restore for
+ // the same reason as above. inoByKey is protected by inoMu.
+ inoByKey map[inoKey]uint64 `state:"nosave"`
+
// lastIno is the last inode number assigned to a file. lastIno is accessed
// using atomic memory operations.
lastIno uint64
@@ -215,6 +226,10 @@ type filesystemOptions struct {
// way that application FDs representing "special files" such as sockets
// do. Note that this disables client caching and mmap for regular files.
regularFilesUseSpecialFileFD bool
+
+ // lisaEnabled indicates whether the client will use lisafs protocol to
+ // communicate with the server instead of 9P.
+ lisaEnabled bool
}
// InteropMode controls the client's interaction with other remote filesystem
@@ -428,6 +443,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
delete(mopts, moptOverlayfsStaleRead)
fsopts.overlayfsStaleRead = true
}
+ if lisafs, ok := mopts[moptLisafs]; ok {
+ delete(mopts, moptLisafs)
+ fsopts.lisaEnabled, err = strconv.ParseBool(lisafs)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid lisafs option: %s", lisafs)
+ return nil, nil, linuxerr.EINVAL
+ }
+ }
// fsopts.regularFilesUseSpecialFileFD can only be enabled by specifying
// "cache=none".
@@ -459,44 +482,83 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
syncableDentries: make(map[*dentry]struct{}),
specialFileFDs: make(map[*specialFileFD]struct{}),
inoByQIDPath: make(map[uint64]uint64),
+ inoByKey: make(map[inoKey]uint64),
}
fs.vfsfs.Init(vfsObj, &fstype, fs)
+ if err := fs.initClientAndRoot(ctx); err != nil {
+ fs.vfsfs.DecRef(ctx)
+ return nil, nil, err
+ }
+
+ return &fs.vfsfs, &fs.root.vfsd, nil
+}
+
+func (fs *filesystem) initClientAndRoot(ctx context.Context) error {
+ var err error
+ if fs.opts.lisaEnabled {
+ var rootInode *lisafs.Inode
+ rootInode, err = fs.initClientLisa(ctx)
+ if err != nil {
+ return err
+ }
+ fs.root, err = fs.newDentryLisa(ctx, rootInode)
+ if err != nil {
+ fs.clientLisa.CloseFDBatched(ctx, rootInode.ControlFD)
+ }
+ } else {
+ fs.root, err = fs.initClient(ctx)
+ }
+
+ // Set the root's reference count to 2. One reference is returned to the
+ // caller, and the other is held by fs to prevent the root from being "cached"
+ // and subsequently evicted.
+ if err == nil {
+ fs.root.refs = 2
+ }
+ return err
+}
+
+func (fs *filesystem) initClientLisa(ctx context.Context) (*lisafs.Inode, error) {
+ sock, err := unet.NewSocket(fs.opts.fd)
+ if err != nil {
+ return nil, err
+ }
+
+ var rootInode *lisafs.Inode
+ ctx.UninterruptibleSleepStart(false)
+ fs.clientLisa, rootInode, err = lisafs.NewClient(sock, fs.opts.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ return rootInode, err
+}
+
+func (fs *filesystem) initClient(ctx context.Context) (*dentry, error) {
// Connect to the server.
if err := fs.dial(ctx); err != nil {
- return nil, nil, err
+ return nil, err
}
// Perform attach to obtain the filesystem root.
ctx.UninterruptibleSleepStart(false)
- attached, err := fs.client.Attach(fsopts.aname)
+ attached, err := fs.client.Attach(fs.opts.aname)
ctx.UninterruptibleSleepFinish(false)
if err != nil {
- fs.vfsfs.DecRef(ctx)
- return nil, nil, err
+ return nil, err
}
attachFile := p9file{attached}
qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
if err != nil {
attachFile.close(ctx)
- fs.vfsfs.DecRef(ctx)
- return nil, nil, err
+ return nil, err
}
// Construct the root dentry.
root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr)
if err != nil {
attachFile.close(ctx)
- fs.vfsfs.DecRef(ctx)
- return nil, nil, err
+ return nil, err
}
- // Set the root's reference count to 2. One reference is returned to the
- // caller, and the other is held by fs to prevent the root from being "cached"
- // and subsequently evicted.
- root.refs = 2
- fs.root = root
-
- return &fs.vfsfs, &root.vfsd, nil
+ return root, nil
}
func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) {
@@ -614,7 +676,11 @@ func (fs *filesystem) Release(ctx context.Context) {
if !fs.iopts.LeakConnection {
// Close the connection to the server. This implicitly clunks all fids.
- fs.client.Close()
+ if fs.opts.lisaEnabled {
+ fs.clientLisa.Close()
+ } else {
+ fs.client.Close()
+ }
}
fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
@@ -645,6 +711,23 @@ func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) {
}
}
+// inoKey is the key used to identify the inode backed by this dentry.
+//
+// +stateify savable
+type inoKey struct {
+ ino uint64
+ devMinor uint32
+ devMajor uint32
+}
+
+func inoKeyFromStat(stat *linux.Statx) inoKey {
+ return inoKey{
+ ino: stat.Ino,
+ devMinor: stat.DevMinor,
+ devMajor: stat.DevMajor,
+ }
+}
+
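As a standalone illustration of what this key buys (the real lookup is inoFromKey below): stats that report the same (ino, devMajor, devMinor) tuple, e.g. hard links, map to one sentry inode number. A minimal sketch:

    package main

    import "fmt"

    type inoKey struct {
        ino      uint64
        devMinor uint32
        devMajor uint32
    }

    var (
        inoByKey = map[inoKey]uint64{}
        lastIno  uint64
    )

    // inoFromKey-style lookup: the first sighting of a key allocates a fresh
    // sentry inode number; later sightings (e.g. hard links) reuse it.
    func inoFor(k inoKey) uint64 {
        if ino, ok := inoByKey[k]; ok {
            return ino
        }
        lastIno++
        inoByKey[k] = lastIno
        return lastIno
    }

    func main() {
        a := inoFor(inoKey{ino: 7, devMajor: 8, devMinor: 1})
        b := inoFor(inoKey{ino: 7, devMajor: 8, devMinor: 1}) // hard link
        fmt.Println(a == b)                                   // true
    }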
// dentry implements vfs.DentryImpl.
//
// +stateify savable
@@ -675,6 +758,9 @@ type dentry struct {
// qidPath is the p9.QID.Path for this file. qidPath is immutable.
qidPath uint64
+ // inoKey is used to identify this dentry's inode.
+ inoKey inoKey
+
// file is the unopened p9.File that backs this dentry. file is immutable.
//
// If file.isNil(), this dentry represents a synthetic file, i.e. a file
@@ -682,6 +768,14 @@ type dentry struct {
// only files that can be synthetic are sockets, pipes, and directories.
file p9file `state:"nosave"`
+ // controlFDLisa is used by lisafs to perform path based operations on this
+ // dentry.
+ //
+ // If !controlFDLisa.Ok(), this dentry represents a synthetic file, i.e. a
+ // file that does not exist on the remote filesystem. As of this writing, the
+ // only files that can be synthetic are sockets, pipes, and directories.
+ controlFDLisa lisafs.ClientFD `state:"nosave"`
+
// If deleted is non-zero, the file represented by this dentry has been
// deleted. deleted is accessed using atomic memory operations.
deleted uint32
@@ -792,12 +886,14 @@ type dentry struct {
// always either -1 or equal to readFD; if !writeFile.isNil() (the file has
// been opened for writing), it is additionally either -1 or equal to
// writeFD.
- handleMu sync.RWMutex `state:"nosave"`
- readFile p9file `state:"nosave"`
- writeFile p9file `state:"nosave"`
- readFD int32 `state:"nosave"`
- writeFD int32 `state:"nosave"`
- mmapFD int32 `state:"nosave"`
+ handleMu sync.RWMutex `state:"nosave"`
+ readFile p9file `state:"nosave"`
+ writeFile p9file `state:"nosave"`
+ readFDLisa lisafs.ClientFD `state:"nosave"`
+ writeFDLisa lisafs.ClientFD `state:"nosave"`
+ readFD int32 `state:"nosave"`
+ writeFD int32 `state:"nosave"`
+ mmapFD int32 `state:"nosave"`
dataMu sync.RWMutex `state:"nosave"`
@@ -865,11 +961,11 @@ func dentryAttrMask() p9.AttrMask {
func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, mask p9.AttrMask, attr *p9.Attr) (*dentry, error) {
if !mask.Mode {
ctx.Warningf("can't create gofer.dentry without file type")
- return nil, syserror.EIO
+ return nil, linuxerr.EIO
}
if attr.Mode.FileType() == p9.ModeRegular && !mask.Size {
ctx.Warningf("can't create regular file gofer.dentry without file size")
- return nil, syserror.EIO
+ return nil, linuxerr.EIO
}
d := &dentry{
@@ -921,6 +1017,79 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
return d, nil
}
+func (fs *filesystem) newDentryLisa(ctx context.Context, ino *lisafs.Inode) (*dentry, error) {
+ if ino.Stat.Mask&linux.STATX_TYPE == 0 {
+ ctx.Warningf("can't create gofer.dentry without file type")
+ return nil, linuxerr.EIO
+ }
+ if ino.Stat.Mode&linux.FileTypeMask == linux.ModeRegular && ino.Stat.Mask&linux.STATX_SIZE == 0 {
+ ctx.Warningf("can't create regular file gofer.dentry without file size")
+ return nil, linuxerr.EIO
+ }
+
+ inoKey := inoKeyFromStat(&ino.Stat)
+ d := &dentry{
+ fs: fs,
+ inoKey: inoKey,
+ ino: fs.inoFromKey(inoKey),
+ mode: uint32(ino.Stat.Mode),
+ uid: uint32(fs.opts.dfltuid),
+ gid: uint32(fs.opts.dfltgid),
+ blockSize: hostarch.PageSize,
+ readFD: -1,
+ writeFD: -1,
+ mmapFD: -1,
+ controlFDLisa: fs.clientLisa.NewFD(ino.ControlFD),
+ }
+
+ d.pf.dentry = d
+ if ino.Stat.Mask&linux.STATX_UID != 0 {
+ d.uid = dentryUIDFromLisaUID(lisafs.UID(ino.Stat.UID))
+ }
+ if ino.Stat.Mask&linux.STATX_GID != 0 {
+ d.gid = dentryGIDFromLisaGID(lisafs.GID(ino.Stat.GID))
+ }
+ if ino.Stat.Mask&linux.STATX_SIZE != 0 {
+ d.size = ino.Stat.Size
+ }
+ if ino.Stat.Blksize != 0 {
+ d.blockSize = ino.Stat.Blksize
+ }
+ if ino.Stat.Mask&linux.STATX_ATIME != 0 {
+ d.atime = dentryTimestampFromLisa(ino.Stat.Atime)
+ }
+ if ino.Stat.Mask&linux.STATX_MTIME != 0 {
+ d.mtime = dentryTimestampFromLisa(ino.Stat.Mtime)
+ }
+ if ino.Stat.Mask&linux.STATX_CTIME != 0 {
+ d.ctime = dentryTimestampFromLisa(ino.Stat.Ctime)
+ }
+ if ino.Stat.Mask&linux.STATX_BTIME != 0 {
+ d.btime = dentryTimestampFromLisa(ino.Stat.Btime)
+ }
+ if ino.Stat.Mask&linux.STATX_NLINK != 0 {
+ d.nlink = ino.Stat.Nlink
+ }
+ d.vfsd.Init(d)
+ refsvfs2.Register(d)
+ fs.syncMu.Lock()
+ fs.syncableDentries[d] = struct{}{}
+ fs.syncMu.Unlock()
+ return d, nil
+}
+
+func (fs *filesystem) inoFromKey(key inoKey) uint64 {
+ fs.inoMu.Lock()
+ defer fs.inoMu.Unlock()
+
+ if ino, ok := fs.inoByKey[key]; ok {
+ return ino
+ }
+ ino := fs.nextIno()
+ fs.inoByKey[key] = ino
+ return ino
+}
+
func (fs *filesystem) inoFromQIDPath(qidPath uint64) uint64 {
fs.inoMu.Lock()
defer fs.inoMu.Unlock()
@@ -937,7 +1106,7 @@ func (fs *filesystem) nextIno() uint64 {
}
func (d *dentry) isSynthetic() bool {
- return d.file.isNil()
+ return !d.isControlFileOk()
}
func (d *dentry) cachedMetadataAuthoritative() bool {
@@ -987,6 +1156,50 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
}
}
+// updateFromLisaStatLocked is called to update d's metadata after an update
+// from the remote filesystem.
+// Precondition: d.metadataMu must be locked.
+// +checklocks:d.metadataMu
+func (d *dentry) updateFromLisaStatLocked(stat *linux.Statx) {
+ if stat.Mask&linux.STATX_TYPE != 0 {
+ if got, want := stat.Mode&linux.FileTypeMask, d.fileType(); uint32(got) != want {
+ panic(fmt.Sprintf("gofer.dentry file type changed from %#o to %#o", want, got))
+ }
+ }
+ if stat.Mask&linux.STATX_MODE != 0 {
+ atomic.StoreUint32(&d.mode, uint32(stat.Mode))
+ }
+ if stat.Mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&d.uid, dentryUIDFromLisaUID(lisafs.UID(stat.UID)))
+ }
+ if stat.Mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&d.gid, dentryGIDFromLisaGID(lisafs.GID(stat.GID)))
+ }
+ if stat.Blksize != 0 {
+ atomic.StoreUint32(&d.blockSize, stat.Blksize)
+ }
+ // Don't override newer client-defined timestamps with old server-defined
+ // ones.
+ if stat.Mask&linux.STATX_ATIME != 0 && atomic.LoadUint32(&d.atimeDirty) == 0 {
+ atomic.StoreInt64(&d.atime, dentryTimestampFromLisa(stat.Atime))
+ }
+ if stat.Mask&linux.STATX_MTIME != 0 && atomic.LoadUint32(&d.mtimeDirty) == 0 {
+ atomic.StoreInt64(&d.mtime, dentryTimestampFromLisa(stat.Mtime))
+ }
+ if stat.Mask&linux.STATX_CTIME != 0 {
+ atomic.StoreInt64(&d.ctime, dentryTimestampFromLisa(stat.Ctime))
+ }
+ if stat.Mask&linux.STATX_BTIME != 0 {
+ atomic.StoreInt64(&d.btime, dentryTimestampFromLisa(stat.Btime))
+ }
+ if stat.Mask&linux.STATX_NLINK != 0 {
+ atomic.StoreUint32(&d.nlink, stat.Nlink)
+ }
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ d.updateSizeLocked(stat.Size)
+ }
+}
+
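The atimeDirty/mtimeDirty guards above exist so that a stale server timestamp cannot clobber a newer client-local one. A minimal standalone illustration of that rule:

    package main

    import "fmt"

    func main() {
        atime := int64(2000)    // client-local update, not yet written back
        atimeDirty := uint32(1) // set when the client touches atime locally

        serverAtime := int64(1000) // older value from a Stat RPC
        if atimeDirty == 0 {
            atime = serverAtime // only adopt the server time when not dirty
        }
        fmt.Println(atime) // 2000: the newer client-defined value wins
    }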
// Preconditions: !d.isSynthetic().
// Preconditions: d.metadataMu is locked.
// +checklocks:d.metadataMu
@@ -996,7 +1209,10 @@ func (d *dentry) refreshSizeLocked(ctx context.Context) error {
if d.writeFD < 0 {
d.handleMu.RUnlock()
// Ask the gofer if we don't have a host FD.
- return d.updateFromGetattrLocked(ctx)
+ if d.fs.opts.lisaEnabled {
+ return d.updateFromStatLisaLocked(ctx, nil)
+ }
+ return d.updateFromGetattrLocked(ctx, p9file{})
}
var stat unix.Statx_t
@@ -1015,33 +1231,77 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
// updating stale attributes in d.updateFromP9AttrsLocked().
d.metadataMu.Lock()
defer d.metadataMu.Unlock()
- return d.updateFromGetattrLocked(ctx)
+ if d.fs.opts.lisaEnabled {
+ return d.updateFromStatLisaLocked(ctx, nil)
+ }
+ return d.updateFromGetattrLocked(ctx, p9file{})
}
// Preconditions:
// * !d.isSynthetic().
// * d.metadataMu is locked.
// +checklocks:d.metadataMu
-func (d *dentry) updateFromGetattrLocked(ctx context.Context) error {
- // Use d.readFile or d.writeFile, which represent 9P FIDs that have been
- // opened, in preference to d.file, which represents a 9P fid that has not.
- // This may be significantly more efficient in some implementations. Prefer
- // d.writeFile over d.readFile since some filesystem implementations may
- // update a writable handle's metadata after writes to that handle, without
- // making metadata updates immediately visible to read-only handles
- // representing the same file.
- d.handleMu.RLock()
- handleMuRLocked := true
- var file p9file
- switch {
- case !d.writeFile.isNil():
- file = d.writeFile
- case !d.readFile.isNil():
- file = d.readFile
- default:
- file = d.file
- d.handleMu.RUnlock()
- handleMuRLocked = false
+func (d *dentry) updateFromStatLisaLocked(ctx context.Context, fdLisa *lisafs.ClientFD) error {
+ handleMuRLocked := false
+ if fdLisa == nil {
+ // Use open FDs in preference to the control FD. This may be significantly
+ // more efficient in some implementations. Prefer a writable FD over a
+ // readable one since some filesystem implementations may update a writable
+ // FD's metadata after writes, without making metadata updates immediately
+ // visible to read-only FDs representing the same file.
+ d.handleMu.RLock()
+ switch {
+ case d.writeFDLisa.Ok():
+ fdLisa = &d.writeFDLisa
+ handleMuRLocked = true
+ case d.readFDLisa.Ok():
+ fdLisa = &d.readFDLisa
+ handleMuRLocked = true
+ default:
+ fdLisa = &d.controlFDLisa
+ d.handleMu.RUnlock()
+ }
+ }
+
+ var stat linux.Statx
+ err := fdLisa.StatTo(ctx, &stat)
+ if handleMuRLocked {
+ // handleMu must be released before updateFromLisaStatLocked().
+ d.handleMu.RUnlock() // +checklocksforce: complex case.
+ }
+ if err != nil {
+ return err
+ }
+ d.updateFromLisaStatLocked(&stat)
+ return nil
+}
+
+// Preconditions:
+// * !d.isSynthetic().
+// * d.metadataMu is locked.
+// +checklocks:d.metadataMu
+func (d *dentry) updateFromGetattrLocked(ctx context.Context, file p9file) error {
+ handleMuRLocked := false
+ if file.isNil() {
+ // Use d.readFile or d.writeFile, which represent 9P FIDs that have
+ // been opened, in preference to d.file, which represents a 9P fid that
+ // has not. This may be significantly more efficient in some
+ // implementations. Prefer d.writeFile over d.readFile since some
+ // filesystem implementations may update a writable handle's metadata
+ // after writes to that handle, without making metadata updates
+ // immediately visible to read-only handles representing the same file.
+ d.handleMu.RLock()
+ switch {
+ case !d.writeFile.isNil():
+ file = d.writeFile
+ handleMuRLocked = true
+ case !d.readFile.isNil():
+ file = d.readFile
+ handleMuRLocked = true
+ default:
+ file = d.file
+ d.handleMu.RUnlock()
+ }
}
_, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask())
@@ -1112,7 +1372,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
case linux.S_IFREG:
// ok
case linux.S_IFDIR:
- return syserror.EISDIR
+ return linuxerr.EISDIR
default:
return linuxerr.EINVAL
}
@@ -1159,6 +1419,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
}
}
+ // failureMask indicates which attributes could not be set on the remote
+ // filesystem. p9 returns an error if any of the attributes could not be
+ // set, but that leads to inconsistency: the server may have set some
+ // attributes successfully, yet a later failure prevents the successful
+ // ones from being updated in the dentry cache.
+ var failureMask uint32
+ var failureErr error
if !d.isSynthetic() {
if stat.Mask != 0 {
if stat.Mask&linux.STATX_SIZE != 0 {
@@ -1168,35 +1435,50 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
// the remote file has been truncated).
d.dataMu.Lock()
}
- if err := d.file.setAttr(ctx, p9.SetAttrMask{
- Permissions: stat.Mask&linux.STATX_MODE != 0,
- UID: stat.Mask&linux.STATX_UID != 0,
- GID: stat.Mask&linux.STATX_GID != 0,
- Size: stat.Mask&linux.STATX_SIZE != 0,
- ATime: stat.Mask&linux.STATX_ATIME != 0,
- MTime: stat.Mask&linux.STATX_MTIME != 0,
- ATimeNotSystemTime: stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW,
- MTimeNotSystemTime: stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW,
- }, p9.SetAttr{
- Permissions: p9.FileMode(stat.Mode),
- UID: p9.UID(stat.UID),
- GID: p9.GID(stat.GID),
- Size: stat.Size,
- ATimeSeconds: uint64(stat.Atime.Sec),
- ATimeNanoSeconds: uint64(stat.Atime.Nsec),
- MTimeSeconds: uint64(stat.Mtime.Sec),
- MTimeNanoSeconds: uint64(stat.Mtime.Nsec),
- }); err != nil {
- if stat.Mask&linux.STATX_SIZE != 0 {
- d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ if d.fs.opts.lisaEnabled {
+ var err error
+ failureMask, failureErr, err = d.controlFDLisa.SetStat(ctx, stat)
+ if err != nil {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ }
+ return err
+ }
+ } else {
+ if err := d.file.setAttr(ctx, p9.SetAttrMask{
+ Permissions: stat.Mask&linux.STATX_MODE != 0,
+ UID: stat.Mask&linux.STATX_UID != 0,
+ GID: stat.Mask&linux.STATX_GID != 0,
+ Size: stat.Mask&linux.STATX_SIZE != 0,
+ ATime: stat.Mask&linux.STATX_ATIME != 0,
+ MTime: stat.Mask&linux.STATX_MTIME != 0,
+ ATimeNotSystemTime: stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW,
+ MTimeNotSystemTime: stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW,
+ }, p9.SetAttr{
+ Permissions: p9.FileMode(stat.Mode),
+ UID: p9.UID(stat.UID),
+ GID: p9.GID(stat.GID),
+ Size: stat.Size,
+ ATimeSeconds: uint64(stat.Atime.Sec),
+ ATimeNanoSeconds: uint64(stat.Atime.Nsec),
+ MTimeSeconds: uint64(stat.Mtime.Sec),
+ MTimeNanoSeconds: uint64(stat.Mtime.Nsec),
+ }); err != nil {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ }
+ return err
}
- return err
}
if stat.Mask&linux.STATX_SIZE != 0 {
- // d.size should be kept up to date, and privatized
- // copy-on-write mappings of truncated pages need to be
- // invalidated, even if InteropModeShared is in effect.
- d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above
+ if failureMask&linux.STATX_SIZE == 0 {
+ // d.size should be kept up to date, and privatized
+ // copy-on-write mappings of truncated pages need to be
+ // invalidated, even if InteropModeShared is in effect.
+ d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above
+ } else {
+ d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ }
}
}
if d.fs.opts.interop == InteropModeShared {
@@ -1207,13 +1489,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
return nil
}
}
- if stat.Mask&linux.STATX_MODE != 0 {
+ if stat.Mask&linux.STATX_MODE != 0 && failureMask&linux.STATX_MODE == 0 {
atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
}
- if stat.Mask&linux.STATX_UID != 0 {
+ if stat.Mask&linux.STATX_UID != 0 && failureMask&linux.STATX_UID == 0 {
atomic.StoreUint32(&d.uid, stat.UID)
}
- if stat.Mask&linux.STATX_GID != 0 {
+ if stat.Mask&linux.STATX_GID != 0 && failureMask&linux.STATX_GID == 0 {
atomic.StoreUint32(&d.gid, stat.GID)
}
// Note that stat.Atime.Nsec and stat.Mtime.Nsec can't be UTIME_NOW because
@@ -1221,15 +1503,19 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
// stat.Mtime to client-local timestamps above, and if
// !d.cachedMetadataAuthoritative() then we returned after calling
// d.file.setAttr(). For the same reason, now must have been initialized.
- if stat.Mask&linux.STATX_ATIME != 0 {
+ if stat.Mask&linux.STATX_ATIME != 0 && failureMask&linux.STATX_ATIME == 0 {
atomic.StoreInt64(&d.atime, stat.Atime.ToNsec())
atomic.StoreUint32(&d.atimeDirty, 0)
}
- if stat.Mask&linux.STATX_MTIME != 0 {
+ if stat.Mask&linux.STATX_MTIME != 0 && failureMask&linux.STATX_MTIME == 0 {
atomic.StoreInt64(&d.mtime, stat.Mtime.ToNsec())
atomic.StoreUint32(&d.mtimeDirty, 0)
}
atomic.StoreInt64(&d.ctime, now)
+ if failureMask != 0 {
+ // Setting some attribute failed on the remote filesystem.
+ return failureErr
+ }
return nil
}
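A standalone sketch of the failureMask contract assumed by setStat: attributes that did get set remotely still update the dentry cache, and failureErr is surfaced only at the end. The STATX constants here are placeholders for the linux package's values:

    package main

    import "fmt"

    // Placeholder values; the real code uses linux.STATX_MODE etc.
    const (
        statxMode = 0x002
        statxSize = 0x200
    )

    func main() {
        requested := uint32(statxMode | statxSize)
        failureMask := uint32(statxSize) // say the remote truncate failed
        failureErr := fmt.Errorf("remote SetStat: size not set")

        if requested&statxMode != 0 && failureMask&statxMode == 0 {
            fmt.Println("mode was set remotely; update the cached mode")
        }
        if failureMask != 0 {
            fmt.Println("surface the partial failure:", failureErr)
        }
    }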
@@ -1309,7 +1595,10 @@ func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats
// (b/148380782). Allow all other extended attributes to be passed through
// to the remote filesystem. This is inconsistent with Linux's 9p client,
// but consistent with other filesystems (e.g. FUSE).
- if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) {
+ //
+ // NOTE(b/202533394): Also disallow "trusted" namespace for now. This is
+ // consistent with the VFS1 gofer client.
+ if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) || strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) {
return linuxerr.EOPNOTSUPP
}
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
@@ -1345,6 +1634,20 @@ func dentryGIDFromP9GID(gid p9.GID) uint32 {
return uint32(gid)
}
+func dentryUIDFromLisaUID(uid lisafs.UID) uint32 {
+ if !uid.Ok() {
+ return uint32(auth.OverflowUID)
+ }
+ return uint32(uid)
+}
+
+func dentryGIDFromLisaGID(gid lisafs.GID) uint32 {
+ if !gid.Ok() {
+ return uint32(auth.OverflowGID)
+ }
+ return uint32(gid)
+}
+
// IncRef implements vfs.DentryImpl.IncRef.
func (d *dentry) IncRef() {
// d.refs may be 0 if d.fs.renameMu is locked, which serializes against
@@ -1653,15 +1956,24 @@ func (d *dentry) destroyLocked(ctx context.Context) {
d.dirty.RemoveAll()
}
d.dataMu.Unlock()
- // Clunk open fids and close open host FDs.
- if !d.readFile.isNil() {
- _ = d.readFile.close(ctx)
- }
- if !d.writeFile.isNil() && d.readFile != d.writeFile {
- _ = d.writeFile.close(ctx)
+ if d.fs.opts.lisaEnabled {
+ if d.readFDLisa.Ok() && d.readFDLisa.ID() != d.writeFDLisa.ID() {
+ d.readFDLisa.CloseBatched(ctx)
+ }
+ if d.writeFDLisa.Ok() {
+ d.writeFDLisa.CloseBatched(ctx)
+ }
+ } else {
+ // Clunk open fids and close open host FDs.
+ if !d.readFile.isNil() {
+ _ = d.readFile.close(ctx)
+ }
+ if !d.writeFile.isNil() && d.readFile != d.writeFile {
+ _ = d.writeFile.close(ctx)
+ }
+ d.readFile = p9file{}
+ d.writeFile = p9file{}
}
- d.readFile = p9file{}
- d.writeFile = p9file{}
if d.readFD >= 0 {
_ = unix.Close(int(d.readFD))
}
@@ -1673,7 +1985,7 @@ func (d *dentry) destroyLocked(ctx context.Context) {
d.mmapFD = -1
d.handleMu.Unlock()
- if !d.file.isNil() {
+ if d.isControlFileOk() {
// Note that it's possible that d.atimeDirty or d.mtimeDirty are true,
// i.e. client and server timestamps may differ (because e.g. a client
// write was serviced by the page cache, and only written back to the
@@ -1682,10 +1994,16 @@ func (d *dentry) destroyLocked(ctx context.Context) {
// instantiated for the same file would remain coherent. Unfortunately,
// this turns out to be too expensive in many cases, so for now we
// don't do this.
- if err := d.file.close(ctx); err != nil {
- log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err)
+
+ // Close the control FD.
+ if d.fs.opts.lisaEnabled {
+ d.controlFDLisa.CloseBatched(ctx)
+ } else {
+ if err := d.file.close(ctx); err != nil {
+ log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err)
+ }
+ d.file = p9file{}
}
- d.file = p9file{}
// Remove d from the set of syncable dentries.
d.fs.syncMu.Lock()
@@ -1711,10 +2029,29 @@ func (d *dentry) setDeleted() {
atomic.StoreUint32(&d.deleted, 1)
}
-func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) {
- if d.file.isNil() {
+func (d *dentry) isControlFileOk() bool {
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.Ok()
+ }
+ return !d.file.isNil()
+}
+
+func (d *dentry) isReadFileOk() bool {
+ if d.fs.opts.lisaEnabled {
+ return d.readFDLisa.Ok()
+ }
+ return !d.readFile.isNil()
+}
+
+func (d *dentry) listXattr(ctx context.Context, size uint64) ([]string, error) {
+ if !d.isControlFileOk() {
return nil, nil
}
+
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.ListXattr(ctx, size)
+ }
+
xattrMap, err := d.file.listXattr(ctx, size)
if err != nil {
return nil, err
@@ -1727,32 +2064,41 @@ func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size ui
}
func (d *dentry) getXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) {
- if d.file.isNil() {
+ if !d.isControlFileOk() {
return "", linuxerr.ENODATA
}
if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil {
return "", err
}
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.GetXattr(ctx, opts.Name, opts.Size)
+ }
return d.file.getXattr(ctx, opts.Name, opts.Size)
}
func (d *dentry) setXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetXattrOptions) error {
- if d.file.isNil() {
+ if !d.isControlFileOk() {
return linuxerr.EPERM
}
if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil {
return err
}
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.SetXattr(ctx, opts.Name, opts.Value, opts.Flags)
+ }
return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags)
}
func (d *dentry) removeXattr(ctx context.Context, creds *auth.Credentials, name string) error {
- if d.file.isNil() {
+ if !d.isControlFileOk() {
return linuxerr.EPERM
}
if err := d.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil {
return err
}
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.RemoveXattr(ctx, name)
+ }
return d.file.removeXattr(ctx, name)
}
@@ -1764,19 +2110,30 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// O_TRUNC).
if !trunc {
d.handleMu.RLock()
- if (!read || !d.readFile.isNil()) && (!write || !d.writeFile.isNil()) {
+ var canReuseCurHandle bool
+ if d.fs.opts.lisaEnabled {
+ canReuseCurHandle = (!read || d.readFDLisa.Ok()) && (!write || d.writeFDLisa.Ok())
+ } else {
+ canReuseCurHandle = (!read || !d.readFile.isNil()) && (!write || !d.writeFile.isNil())
+ }
+ d.handleMu.RUnlock()
+ if canReuseCurHandle {
// Current handles are sufficient.
- d.handleMu.RUnlock()
return nil
}
- d.handleMu.RUnlock()
}
var fdsToCloseArr [2]int32
fdsToClose := fdsToCloseArr[:0]
invalidateTranslations := false
d.handleMu.Lock()
- if (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc {
+ var needNewHandle bool
+ if d.fs.opts.lisaEnabled {
+ needNewHandle = (read && !d.readFDLisa.Ok()) || (write && !d.writeFDLisa.Ok()) || trunc
+ } else {
+ needNewHandle = (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc
+ }
+ if needNewHandle {
// Get a new handle. If this file has been opened for both reading and
// writing, try to get a single handle that is usable for both:
//
@@ -1785,9 +2142,21 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
//
// - NOTE(b/141991141): Some filesystems may not ensure coherence
// between multiple handles for the same file.
- openReadable := !d.readFile.isNil() || read
- openWritable := !d.writeFile.isNil() || write
- h, err := openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ var (
+ openReadable bool
+ openWritable bool
+ h handle
+ err error
+ )
+ if d.fs.opts.lisaEnabled {
+ openReadable = d.readFDLisa.Ok() || read
+ openWritable = d.writeFDLisa.Ok() || write
+ h, err = openHandleLisa(ctx, d.controlFDLisa, openReadable, openWritable, trunc)
+ } else {
+ openReadable = !d.readFile.isNil() || read
+ openWritable = !d.writeFile.isNil() || write
+ h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ }
if linuxerr.Equals(linuxerr.EACCES, err) && (openReadable != read || openWritable != write) {
// It may not be possible to use a single handle for both
// reading and writing, since permissions on the file may have
@@ -1797,7 +2166,11 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
ctx.Debugf("gofer.dentry.ensureSharedHandle: bifurcating read/write handles for dentry %p", d)
openReadable = read
openWritable = write
- h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ if d.fs.opts.lisaEnabled {
+ h, err = openHandleLisa(ctx, d.controlFDLisa, openReadable, openWritable, trunc)
+ } else {
+ h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ }
}
if err != nil {
d.handleMu.Unlock()
@@ -1859,9 +2232,16 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// previously opened for reading (without an FD), then existing
// translations of the file may use the internal page cache;
// invalidate those mappings.
- if d.writeFile.isNil() {
- invalidateTranslations = !d.readFile.isNil()
- atomic.StoreInt32(&d.mmapFD, h.fd)
+ if d.fs.opts.lisaEnabled {
+ if !d.writeFDLisa.Ok() {
+ invalidateTranslations = d.readFDLisa.Ok()
+ atomic.StoreInt32(&d.mmapFD, h.fd)
+ }
+ } else {
+ if d.writeFile.isNil() {
+ invalidateTranslations = !d.readFile.isNil()
+ atomic.StoreInt32(&d.mmapFD, h.fd)
+ }
}
} else if openWritable && d.writeFD < 0 {
atomic.StoreInt32(&d.writeFD, h.fd)
@@ -1888,24 +2268,45 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
atomic.StoreInt32(&d.mmapFD, -1)
}
- // Switch to new fids.
- var oldReadFile p9file
- if openReadable {
- oldReadFile = d.readFile
- d.readFile = h.file
- }
- var oldWriteFile p9file
- if openWritable {
- oldWriteFile = d.writeFile
- d.writeFile = h.file
- }
- // NOTE(b/141991141): Clunk old fids before making new fids visible (by
- // unlocking d.handleMu).
- if !oldReadFile.isNil() {
- oldReadFile.close(ctx)
- }
- if !oldWriteFile.isNil() && oldReadFile != oldWriteFile {
- oldWriteFile.close(ctx)
+ // Switch to new fids/FDs.
+ if d.fs.opts.lisaEnabled {
+ oldReadFD := lisafs.InvalidFDID
+ if openReadable {
+ oldReadFD = d.readFDLisa.ID()
+ d.readFDLisa = h.fdLisa
+ }
+ oldWriteFD := lisafs.InvalidFDID
+ if openWritable {
+ oldWriteFD = d.writeFDLisa.ID()
+ d.writeFDLisa = h.fdLisa
+ }
+ // NOTE(b/141991141): Close old FDs before making new FDs visible (by
+ // unlocking d.handleMu).
+ if oldReadFD.Ok() {
+ d.fs.clientLisa.CloseFDBatched(ctx, oldReadFD)
+ }
+ if oldWriteFD.Ok() && oldReadFD != oldWriteFD {
+ d.fs.clientLisa.CloseFDBatched(ctx, oldWriteFD)
+ }
+ } else {
+ var oldReadFile p9file
+ if openReadable {
+ oldReadFile = d.readFile
+ d.readFile = h.file
+ }
+ var oldWriteFile p9file
+ if openWritable {
+ oldWriteFile = d.writeFile
+ d.writeFile = h.file
+ }
+ // NOTE(b/141991141): Clunk old fids before making new fids visible (by
+ // unlocking d.handleMu).
+ if !oldReadFile.isNil() {
+ oldReadFile.close(ctx)
+ }
+ if !oldWriteFile.isNil() && oldReadFile != oldWriteFile {
+ oldWriteFile.close(ctx)
+ }
}
}
d.handleMu.Unlock()
@@ -1929,27 +2330,29 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// Preconditions: d.handleMu must be locked.
func (d *dentry) readHandleLocked() handle {
return handle{
- file: d.readFile,
- fd: d.readFD,
+ fdLisa: d.readFDLisa,
+ file: d.readFile,
+ fd: d.readFD,
}
}
// Preconditions: d.handleMu must be locked.
func (d *dentry) writeHandleLocked() handle {
return handle{
- file: d.writeFile,
- fd: d.writeFD,
+ fdLisa: d.writeFDLisa,
+ file: d.writeFile,
+ fd: d.writeFD,
}
}
func (d *dentry) syncRemoteFile(ctx context.Context) error {
d.handleMu.RLock()
defer d.handleMu.RUnlock()
- return d.syncRemoteFileLocked(ctx)
+ return d.syncRemoteFileLocked(ctx, nil /* accFsyncFDIDsLisa */)
}
// Preconditions: d.handleMu must be locked.
-func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
+func (d *dentry) syncRemoteFileLocked(ctx context.Context, accFsyncFDIDsLisa *[]lisafs.FDID) error {
// If we have a host FD, fsyncing it is likely to be faster than an fsync
// RPC. Prefer syncing write handles over read handles, since some remote
// filesystem implementations may not sync changes made through write
@@ -1960,7 +2363,13 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
ctx.UninterruptibleSleepFinish(false)
return err
}
- if !d.writeFile.isNil() {
+ if d.fs.opts.lisaEnabled && d.writeFDLisa.Ok() {
+ if accFsyncFDIDsLisa != nil {
+ *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, d.writeFDLisa.ID())
+ return nil
+ }
+ return d.writeFDLisa.Sync(ctx)
+ } else if !d.fs.opts.lisaEnabled && !d.writeFile.isNil() {
return d.writeFile.fsync(ctx)
}
if d.readFD >= 0 {
@@ -1969,13 +2378,19 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
ctx.UninterruptibleSleepFinish(false)
return err
}
- if !d.readFile.isNil() {
+ if d.fs.opts.lisaEnabled && d.readFDLisa.Ok() {
+ if accFsyncFDIDsLisa != nil {
+ *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, d.readFDLisa.ID())
+ return nil
+ }
+ return d.readFDLisa.Sync(ctx)
+ } else if !d.fs.opts.lisaEnabled && !d.readFile.isNil() {
return d.readFile.fsync(ctx)
}
return nil
}
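A standalone sketch of the accFsyncFDIDsLisa accumulation pattern: with a non-nil accumulator, each dentry queues its FD's ID instead of issuing one fsync RPC per file. The final batched RPC is left hypothetical here; only the accumulation mirrors the code above:

    package main

    import "fmt"

    type FDID uint32

    // syncRemoteFileLocked-style accumulation: with a non-nil accumulator
    // the FD's ID is queued for one batched fsync RPC instead of syncing
    // inline.
    func syncOne(fdid FDID, acc *[]FDID) {
        if acc != nil {
            *acc = append(*acc, fdid)
            return
        }
        fmt.Println("individual fsync RPC for", fdid)
    }

    func main() {
        var ids []FDID
        for _, fdid := range []FDID{3, 5, 9} {
            syncOne(fdid, &ids)
        }
        // A filesystem-wide sync would now issue one batched RPC for ids.
        fmt.Println("batched fsync for", ids)
    }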
-func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) error {
+func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool, accFsyncFDIDsLisa *[]lisafs.FDID) error {
d.handleMu.RLock()
defer d.handleMu.RUnlock()
h := d.writeHandleLocked()
@@ -1988,7 +2403,7 @@ func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) err
return err
}
}
- if err := d.syncRemoteFileLocked(ctx); err != nil {
+ if err := d.syncRemoteFileLocked(ctx, accFsyncFDIDsLisa); err != nil {
if !forFilesystemSync {
return err
}
@@ -2045,10 +2460,33 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
d := fd.dentry()
const validMask = uint32(linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME)
if !d.cachedMetadataAuthoritative() && opts.Mask&validMask != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC {
- // TODO(jamieliu): Use specialFileFD.handle.file for the getattr if
- // available?
- if err := d.updateFromGetattr(ctx); err != nil {
- return linux.Statx{}, err
+ if d.fs.opts.lisaEnabled {
+ // Use specialFileFD.handle.fdLisa for the Stat if available, for the
+ // same reason that we try to use open FDs in updateFromStatLisaLocked().
+ var fdLisa *lisafs.ClientFD
+ if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok {
+ fdLisa = &sffd.handle.fdLisa
+ }
+ d.metadataMu.Lock()
+ err := d.updateFromStatLisaLocked(ctx, fdLisa)
+ d.metadataMu.Unlock()
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ } else {
+ // Use specialFileFD.handle.file for the getattr if available, for the
+ // same reason that we try to use open file handles in
+ // dentry.updateFromGetattrLocked().
+ var file p9file
+ if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok {
+ file = sffd.handle.file
+ }
+ d.metadataMu.Lock()
+ err := d.updateFromGetattrLocked(ctx, file)
+ d.metadataMu.Unlock()
+ if err != nil {
+ return linux.Statx{}, err
+ }
}
}
var stat linux.Statx
@@ -2069,7 +2507,7 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
// ListXattr implements vfs.FileDescriptionImpl.ListXattr.
func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) {
- return fd.dentry().listXattr(ctx, auth.CredentialsFromContext(ctx), size)
+ return fd.dentry().listXattr(ctx, size)
}
// GetXattr implements vfs.FileDescriptionImpl.GetXattr.
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
index 806392d50..d5cc73f33 100644
--- a/pkg/sentry/fsimpl/gofer/gofer_test.go
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -33,6 +33,7 @@ func TestDestroyIdempotent(t *testing.T) {
},
syncableDentries: make(map[*dentry]struct{}),
inoByQIDPath: make(map[uint64]uint64),
+ inoByKey: make(map[inoKey]uint64),
}
attr := &p9.Attr{
diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go
index 5c57f6fea..394aecd62 100644
--- a/pkg/sentry/fsimpl/gofer/handle.go
+++ b/pkg/sentry/fsimpl/gofer/handle.go
@@ -17,18 +17,23 @@ package gofer
import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/hostfd"
+ "gvisor.dev/gvisor/pkg/sync"
)
// handle represents a remote "open file descriptor", consisting of an opened
// fid (p9.File) and optionally a host file descriptor.
//
+// If lisafs is being used, fdLisa points to an open file on the server.
+//
// These are explicitly not savable.
type handle struct {
- file p9file
- fd int32 // -1 if unavailable
+ fdLisa lisafs.ClientFD
+ file p9file
+ fd int32 // -1 if unavailable
}
// Preconditions: read || write.
@@ -64,13 +69,47 @@ func openHandle(ctx context.Context, file p9file, read, write, trunc bool) (hand
}, nil
}
+// Preconditions: read || write.
+func openHandleLisa(ctx context.Context, fdLisa lisafs.ClientFD, read, write, trunc bool) (handle, error) {
+ var flags uint32
+ switch {
+ case read && write:
+ flags = unix.O_RDWR
+ case read:
+ flags = unix.O_RDONLY
+ case write:
+ flags = unix.O_WRONLY
+ default:
+ panic("tried to open unreadable and unwritable handle")
+ }
+ if trunc {
+ flags |= unix.O_TRUNC
+ }
+ openFD, hostFD, err := fdLisa.OpenAt(ctx, flags)
+ if err != nil {
+ return handle{fd: -1}, err
+ }
+ h := handle{
+ fdLisa: fdLisa.Client().NewFD(openFD),
+ fd: int32(hostFD),
+ }
+ return h, nil
+}
+
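For illustration, openHandleLisa's access-mode mapping extracted as a standalone helper (same logic, placeholder name):

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    // accessFlags mirrors openHandleLisa's read/write/trunc-to-flags mapping.
    func accessFlags(read, write, trunc bool) uint32 {
        var flags uint32
        switch {
        case read && write:
            flags = unix.O_RDWR
        case read:
            flags = unix.O_RDONLY
        case write:
            flags = unix.O_WRONLY
        }
        if trunc {
            flags |= unix.O_TRUNC
        }
        return flags
    }

    func main() {
        fmt.Printf("%#x\n", accessFlags(true, true, true)) // O_RDWR|O_TRUNC
    }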
func (h *handle) isOpen() bool {
+ if h.fdLisa.Client() != nil {
+ return h.fdLisa.Ok()
+ }
return !h.file.isNil()
}
func (h *handle) close(ctx context.Context) {
- h.file.close(ctx)
- h.file = p9file{}
+ if h.fdLisa.Client() != nil {
+ h.fdLisa.CloseBatched(ctx)
+ } else {
+ h.file.close(ctx)
+ h.file = p9file{}
+ }
if h.fd >= 0 {
unix.Close(int(h.fd))
h.fd = -1
@@ -88,19 +127,27 @@ func (h *handle) readToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offs
return n, err
}
if dsts.NumBlocks() == 1 && !dsts.Head().NeedSafecopy() {
- n, err := h.file.readAt(ctx, dsts.Head().ToSlice(), offset)
- return uint64(n), err
+ if h.fdLisa.Client() != nil {
+ return h.fdLisa.Read(ctx, dsts.Head().ToSlice(), offset)
+ }
+ return h.file.readAt(ctx, dsts.Head().ToSlice(), offset)
}
// Buffer the read since p9.File.ReadAt() takes []byte.
buf := make([]byte, dsts.NumBytes())
- n, err := h.file.readAt(ctx, buf, offset)
+ var n uint64
+ var err error
+ if h.fdLisa.Client() != nil {
+ n, err = h.fdLisa.Read(ctx, buf, offset)
+ } else {
+ n, err = h.file.readAt(ctx, buf, offset)
+ }
if n == 0 {
return 0, err
}
if cp, cperr := safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:n]))); cperr != nil {
return cp, cperr
}
- return uint64(n), err
+ return n, err
}
func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
@@ -114,8 +161,10 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o
return n, err
}
if srcs.NumBlocks() == 1 && !srcs.Head().NeedSafecopy() {
- n, err := h.file.writeAt(ctx, srcs.Head().ToSlice(), offset)
- return uint64(n), err
+ if h.fdLisa.Client() != nil {
+ return h.fdLisa.Write(ctx, srcs.Head().ToSlice(), offset)
+ }
+ return h.file.writeAt(ctx, srcs.Head().ToSlice(), offset)
}
// Buffer the write since p9.File.WriteAt() takes []byte.
buf := make([]byte, srcs.NumBytes())
@@ -123,10 +172,56 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o
if cp == 0 {
return 0, cperr
}
- n, err := h.file.writeAt(ctx, buf[:cp], offset)
+ var n uint64
+ var err error
+ if h.fdLisa.Client() != nil {
+ n, err = h.fdLisa.Write(ctx, buf[:cp], offset)
+ } else {
+ n, err = h.file.writeAt(ctx, buf[:cp], offset)
+ }
// err takes precedence over cperr.
if err != nil {
- return uint64(n), err
+ return n, err
}
- return uint64(n), cperr
+ return n, cperr
+}
+
+type handleReadWriter struct {
+ ctx context.Context
+ h *handle
+ off uint64
+}
+
+var handleReadWriterPool = sync.Pool{
+ New: func() interface{} {
+ return &handleReadWriter{}
+ },
+}
+
+func getHandleReadWriter(ctx context.Context, h *handle, offset int64) *handleReadWriter {
+ rw := handleReadWriterPool.Get().(*handleReadWriter)
+ rw.ctx = ctx
+ rw.h = h
+ rw.off = uint64(offset)
+ return rw
+}
+
+func putHandleReadWriter(rw *handleReadWriter) {
+ rw.ctx = nil
+ rw.h = nil
+ handleReadWriterPool.Put(rw)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *handleReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ n, err := rw.h.readToBlocksAt(rw.ctx, dsts, rw.off)
+ rw.off += n
+ return n, err
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (rw *handleReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ n, err := rw.h.writeFromBlocksAt(rw.ctx, srcs, rw.off)
+ rw.off += n
+ return n, err
}
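The sync.Pool above amortizes handleReadWriter allocations across I/O calls; a standalone sketch of the same get/reset/put discipline:

    package main

    import (
        "fmt"
        "sync"
    )

    type readWriter struct {
        off uint64
        // the real type also carries a context and a *handle
    }

    var pool = sync.Pool{New: func() interface{} { return new(readWriter) }}

    func main() {
        rw := pool.Get().(*readWriter) // cf. getHandleReadWriter
        rw.off = 4096
        fmt.Println("I/O at offset", rw.off)
        rw.off = 0   // cf. putHandleReadWriter: drop state before pooling
        pool.Put(rw) // a later Get from any goroutine may reuse it
    }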
diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
index 398288ee3..505916a57 100644
--- a/pkg/sentry/fsimpl/gofer/host_named_pipe.go
+++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
@@ -22,7 +22,6 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Global pipe used by blockUntilNonblockingPipeHasWriter since we can't create
@@ -109,6 +108,6 @@ func sleepBetweenNamedPipeOpenChecks(ctx context.Context) error {
return nil
case <-cancel:
ctx.SleepFinish(false)
- return syserror.ErrInterrupted
+ return linuxerr.ErrInterrupted
}
}
diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go
index b0a429d42..0d97b60fd 100644
--- a/pkg/sentry/fsimpl/gofer/p9file.go
+++ b/pkg/sentry/fsimpl/gofer/p9file.go
@@ -16,9 +16,9 @@ package gofer
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/pkg/syserror"
)
// p9file is a wrapper around p9.File that provides methods that are
@@ -59,7 +59,7 @@ func (f p9file) walkGetAttrOne(ctx context.Context, name string) (p9.QID, p9file
if newfile != nil {
p9file{newfile}.close(ctx)
}
- return p9.QID{}, p9file{}, p9.AttrMask{}, p9.Attr{}, syserror.EIO
+ return p9.QID{}, p9file{}, p9.AttrMask{}, p9.Attr{}, linuxerr.EIO
}
return qids[0], p9file{newfile}, attrMask, attr, nil
}
@@ -141,18 +141,18 @@ func (f p9file) open(ctx context.Context, flags p9.OpenFlags) (*fd.FD, p9.QID, u
return fdobj, qid, iounit, err
}
-func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (uint64, error) {
ctx.UninterruptibleSleepStart(false)
n, err := f.file.ReadAt(p, offset)
ctx.UninterruptibleSleepFinish(false)
- return n, err
+ return uint64(n), err
}
-func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (uint64, error) {
ctx.UninterruptibleSleepStart(false)
n, err := f.file.WriteAt(p, offset)
ctx.UninterruptibleSleepFinish(false)
- return n, err
+ return uint64(n), err
}
func (f p9file) fsync(ctx context.Context) error {
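
Every p9file wrapper brackets the remote call with UninterruptibleSleepStart/Finish so time blocked on the gofer is accounted correctly; the uint64 return change above just moves the int-to-uint64 conversion into the wrapper. A sketch of the same bracketing pattern, assuming p9.File's StatFS signature:

	// statFS mirrors readAt/writeAt: mark the task uninterruptibly asleep
	// around the round trip to the gofer.
	func (f p9file) statFS(ctx context.Context) (p9.FSStat, error) {
		ctx.UninterruptibleSleepStart(false)
		stat, err := f.file.StatFS()
		ctx.UninterruptibleSleepFinish(false)
		return stat, err
	}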
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 947dbe05f..874f9873d 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -98,6 +98,12 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error {
}
d.handleMu.RLock()
defer d.handleMu.RUnlock()
+ if d.fs.opts.lisaEnabled {
+ if !d.writeFDLisa.Ok() {
+ return nil
+ }
+ return d.writeFDLisa.Flush(ctx)
+ }
if d.writeFile.isNil() {
return nil
}
@@ -110,6 +116,9 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint
return d.doAllocate(ctx, offset, length, func() error {
d.handleMu.RLock()
defer d.handleMu.RUnlock()
+ if d.fs.opts.lisaEnabled {
+ return d.writeFDLisa.Allocate(ctx, mode, offset, length)
+ }
return d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length)
})
}
@@ -282,8 +291,19 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
// changes to the host.
if newMode := vfs.ClearSUIDAndSGID(oldMode); newMode != oldMode {
atomic.StoreUint32(&d.mode, newMode)
- if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil {
- return 0, offset, err
+ if d.fs.opts.lisaEnabled {
+ stat := linux.Statx{Mask: linux.STATX_MODE, Mode: uint16(newMode)}
+ failureMask, failureErr, err := d.controlFDLisa.SetStat(ctx, &stat)
+ if err != nil {
+ return 0, offset, err
+ }
+ if failureMask != 0 {
+ return 0, offset, failureErr
+ }
+ } else {
+ if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil {
+ return 0, offset, err
+ }
}
}
}
@@ -677,7 +697,7 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6
// Sync implements vfs.FileDescriptionImpl.Sync.
func (fd *regularFileFD) Sync(ctx context.Context) error {
- return fd.dentry().syncCachedFile(ctx, false /* lowSyncExpectations */)
+ return fd.dentry().syncCachedFile(ctx, false /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */)
}
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
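
The lisafs SetStat call in pwrite above reports partial failure separately from transport failure: err covers the RPC itself, while failureMask echoes the Statx mask bits the server could not apply and failureErr is the corresponding error. A hedged sketch of the three-way handling this contract implies, using the names from the hunk above:

	// Sketch, assuming the SetStat contract used in pwrite above.
	stat := linux.Statx{Mask: linux.STATX_MODE, Mode: uint16(newMode)}
	failureMask, failureErr, err := d.controlFDLisa.SetStat(ctx, &stat)
	switch {
	case err != nil: // the RPC failed; nothing was applied
		return err
	case failureMask != 0: // e.g. STATX_MODE was rejected by the server
		return failureErr
	}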
diff --git a/pkg/sentry/fsimpl/gofer/revalidate.go b/pkg/sentry/fsimpl/gofer/revalidate.go
index 226790a11..5d4009832 100644
--- a/pkg/sentry/fsimpl/gofer/revalidate.go
+++ b/pkg/sentry/fsimpl/gofer/revalidate.go
@@ -15,7 +15,9 @@
package gofer
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -234,28 +236,54 @@ func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualF
}
// Lock metadata on all dentries *before* getting attributes for them.
state.lockAllMetadata()
- stats, err := state.start.file.multiGetAttr(ctx, state.names)
- if err != nil {
- return err
+
+ var (
+ stats []p9.FullStat
+ statsLisa []linux.Statx
+ numStats int
+ )
+ if fs.opts.lisaEnabled {
+ var err error
+ statsLisa, err = state.start.controlFDLisa.WalkStat(ctx, state.names)
+ if err != nil {
+ return err
+ }
+ numStats = len(statsLisa)
+ } else {
+ var err error
+ stats, err = state.start.file.multiGetAttr(ctx, state.names)
+ if err != nil {
+ return err
+ }
+ numStats = len(stats)
}
i := -1
for d := state.popFront(); d != nil; d = state.popFront() {
i++
- found := i < len(stats)
+ found := i < numStats
if i == 0 && len(state.names[0]) == 0 {
if found && !d.isSynthetic() {
// First dentry is where the search is starting, just update attributes
// since it cannot be replaced.
- d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: acquired by lockAllMetadata.
+ if fs.opts.lisaEnabled {
+ d.updateFromLisaStatLocked(&statsLisa[i]) // +checklocksforce: acquired by lockAllMetadata.
+ } else {
+ d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: acquired by lockAllMetadata.
+ }
}
d.metadataMu.Unlock() // +checklocksforce: see above.
continue
}
- // Note that synthetic dentries will always fails the comparison check
- // below.
- if !found || d.qidPath != stats[i].QID.Path {
+ // Note that synthetic dentries will always fail this comparison check.
+ var shouldInvalidate bool
+ if fs.opts.lisaEnabled {
+ shouldInvalidate = !found || d.inoKey != inoKeyFromStat(&statsLisa[i])
+ } else {
+ shouldInvalidate = !found || d.qidPath != stats[i].QID.Path
+ }
+ if shouldInvalidate {
d.metadataMu.Unlock() // +checklocksforce: see above.
if !found && d.isSynthetic() {
// We have a synthetic file, and no remote file has arisen to replace
@@ -298,7 +326,11 @@ func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualF
}
// The file at this path hasn't changed. Just update cached metadata.
- d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: see above.
+ if fs.opts.lisaEnabled {
+ d.updateFromLisaStatLocked(&statsLisa[i]) // +checklocksforce: see above.
+ } else {
+ d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: see above.
+ }
d.metadataMu.Unlock()
}
diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go
index e67422a2f..475322527 100644
--- a/pkg/sentry/fsimpl/gofer/save_restore.go
+++ b/pkg/sentry/fsimpl/gofer/save_restore.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/safemem"
@@ -112,10 +113,19 @@ func (d *dentry) prepareSaveRecursive(ctx context.Context) error {
return err
}
}
- if !d.readFile.isNil() || !d.writeFile.isNil() {
- d.fs.savedDentryRW[d] = savedDentryRW{
- read: !d.readFile.isNil(),
- write: !d.writeFile.isNil(),
+ if d.fs.opts.lisaEnabled {
+ if d.readFDLisa.Ok() || d.writeFDLisa.Ok() {
+ d.fs.savedDentryRW[d] = savedDentryRW{
+ read: d.readFDLisa.Ok(),
+ write: d.writeFDLisa.Ok(),
+ }
+ }
+ } else {
+ if !d.readFile.isNil() || !d.writeFile.isNil() {
+ d.fs.savedDentryRW[d] = savedDentryRW{
+ read: !d.readFile.isNil(),
+ write: !d.writeFile.isNil(),
+ }
}
}
d.dirMu.Lock()
@@ -158,6 +168,10 @@ func (d *dentryPlatformFile) afterLoad() {
// afterLoad is invoked by stateify.
func (fd *specialFileFD) afterLoad() {
fd.handle.fd = -1
+ if fd.hostFileMapper.IsInited() {
+ // Ensure that we don't call fd.hostFileMapper.Init() again.
+ fd.hostFileMapperInitOnce.Do(func() {})
+ }
}
// CompleteRestore implements
@@ -173,25 +187,37 @@ func (fs *filesystem) CompleteRestore(ctx context.Context, opts vfs.CompleteRest
return fmt.Errorf("no server FD available for filesystem with unique ID %q", fs.iopts.UniqueID)
}
fs.opts.fd = fd
- if err := fs.dial(ctx); err != nil {
- return err
- }
fs.inoByQIDPath = make(map[uint64]uint64)
+ fs.inoByKey = make(map[inoKey]uint64)
- // Restore the filesystem root.
- ctx.UninterruptibleSleepStart(false)
- attached, err := fs.client.Attach(fs.opts.aname)
- ctx.UninterruptibleSleepFinish(false)
- if err != nil {
- return err
- }
- attachFile := p9file{attached}
- qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
- if err != nil {
- return err
- }
- if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil {
- return err
+ if fs.opts.lisaEnabled {
+ rootInode, err := fs.initClientLisa(ctx)
+ if err != nil {
+ return err
+ }
+ if err := fs.root.restoreFileLisa(ctx, rootInode, &opts); err != nil {
+ return err
+ }
+ } else {
+ if err := fs.dial(ctx); err != nil {
+ return err
+ }
+
+ // Restore the filesystem root.
+ ctx.UninterruptibleSleepStart(false)
+ attached, err := fs.client.Attach(fs.opts.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ return err
+ }
+ attachFile := p9file{attached}
+ qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
+ if err != nil {
+ return err
+ }
+ if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil {
+ return err
+ }
}
// Restore remaining dentries.
@@ -279,6 +305,55 @@ func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrM
return nil
}
+func (d *dentry) restoreFileLisa(ctx context.Context, inode *lisafs.Inode, opts *vfs.CompleteRestoreOptions) error {
+ d.controlFDLisa = d.fs.clientLisa.NewFD(inode.ControlFD)
+
+ // Gofers do not preserve inoKey across checkpoint/restore, so:
+ //
+ // - We must assume that the remote filesystem did not change in a way that
+ // would invalidate dentries, since we can't revalidate dentries by
+ // checking inoKey.
+ //
+ // - We need to associate the new inoKey with the existing d.ino.
+ d.inoKey = inoKeyFromStat(&inode.Stat)
+ d.fs.inoMu.Lock()
+ d.fs.inoByKey[d.inoKey] = d.ino
+ d.fs.inoMu.Unlock()
+
+ // Check metadata stability before updating metadata.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ if d.isRegularFile() {
+ if opts.ValidateFileSizes {
+ if inode.Stat.Mask&linux.STATX_SIZE == 0 {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d))
+ }
+ if d.size != inode.Stat.Size {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, inode.Stat.Size)
+ }
+ }
+ if opts.ValidateFileModificationTimestamps {
+ if inode.Stat.Mask&linux.STATX_MTIME == 0 {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d))
+ }
+ if want := dentryTimestampFromLisa(inode.Stat.Mtime); d.mtime != want {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want))
+ }
+ }
+ }
+ if !d.cachedMetadataAuthoritative() {
+ d.updateFromLisaStatLocked(&inode.Stat)
+ }
+
+ if rw, ok := d.fs.savedDentryRW[d]; ok {
+ if err := d.ensureSharedHandle(ctx, rw.read, rw.write, false /* trunc */); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
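
The validation above follows the statx convention: a set bit in Stat.Mask means the gofer returned that field, a clear bit means it is unavailable, so validation must fail when the needed bit is clear. A worked sketch of the size check under that convention:

	// STATX_SIZE clear: the gofer reported no size, validation cannot pass.
	// STATX_SIZE set: inode.Stat.Size is valid and comparable to d.size.
	if inode.Stat.Mask&linux.STATX_SIZE == 0 {
		return fmt.Errorf("file size not available")
	}
	if d.size != inode.Stat.Size {
		return fmt.Errorf("size changed from %d to %d", d.size, inode.Stat.Size)
	}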
// Preconditions: d is not synthetic.
func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error {
for _, child := range d.children {
@@ -301,19 +376,35 @@ func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.Comp
// only be detected by checking filesystem.syncableDentries). d.parent has been
// restored.
func (d *dentry) restoreRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error {
- qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name)
- if err != nil {
- return err
- }
- if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil {
- return err
+ if d.fs.opts.lisaEnabled {
+ inode, err := d.parent.controlFDLisa.Walk(ctx, d.name)
+ if err != nil {
+ return err
+ }
+ if err := d.restoreFileLisa(ctx, inode, opts); err != nil {
+ return err
+ }
+ } else {
+ qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name)
+ if err != nil {
+ return err
+ }
+ if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil {
+ return err
+ }
}
return d.restoreDescendantsRecursive(ctx, opts)
}
func (fd *specialFileFD) completeRestore(ctx context.Context) error {
d := fd.dentry()
- h, err := openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */)
+ var h handle
+ var err error
+ if d.fs.opts.lisaEnabled {
+ h, err = openHandleLisa(ctx, d.controlFDLisa, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */)
+ } else {
+ h, err = openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */)
+ }
if err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go
index fe15f8583..86ab70453 100644
--- a/pkg/sentry/fsimpl/gofer/socket.go
+++ b/pkg/sentry/fsimpl/gofer/socket.go
@@ -59,11 +59,6 @@ func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) {
// BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect.
func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error {
- cf, ok := sockTypeToP9(ce.Type())
- if !ok {
- return syserr.ErrConnectionRefused
- }
-
// No lock ordering required as only the ConnectingEndpoint has a mutex.
ce.Lock()
@@ -77,7 +72,7 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec
return syserr.ErrInvalidEndpointState
}
- c, err := e.newConnectedEndpoint(ctx, cf, ce.WaiterQueue())
+ c, err := e.newConnectedEndpoint(ctx, ce.Type(), ce.WaiterQueue())
if err != nil {
ce.Unlock()
return err
@@ -95,7 +90,7 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec
// UnidirectionalConnect implements
// transport.BoundEndpoint.UnidirectionalConnect.
func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.ConnectedEndpoint, *syserr.Error) {
- c, err := e.newConnectedEndpoint(ctx, p9.DgramSocket, &waiter.Queue{})
+ c, err := e.newConnectedEndpoint(ctx, linux.SOCK_DGRAM, &waiter.Queue{})
if err != nil {
return nil, err
}
@@ -111,25 +106,39 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect
return c, nil
}
-func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) {
- hostFile, err := e.dentry.file.connect(ctx, flags)
- if err != nil {
+func (e *endpoint) newConnectedEndpoint(ctx context.Context, sockType linux.SockType, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) {
+ if e.dentry.fs.opts.lisaEnabled {
+ hostSockFD, err := e.dentry.controlFDLisa.Connect(ctx, sockType)
+ if err != nil {
+ return nil, syserr.ErrConnectionRefused
+ }
+
+ c, serr := host.NewSCMEndpoint(ctx, hostSockFD, queue, e.path)
+ if serr != nil {
+ unix.Close(hostSockFD)
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v sockType %d: %v", e.dentry.file, sockType, serr)
+ return nil, serr
+ }
+ return c, nil
+ }
+
+ flags, ok := sockTypeToP9(sockType)
+ if !ok {
return nil, syserr.ErrConnectionRefused
}
- // Dup the fd so that the new endpoint can manage its lifetime.
- hostFD, err := unix.Dup(hostFile.FD())
+ hostFile, err := e.dentry.file.connect(ctx, flags)
if err != nil {
- log.Warningf("Could not dup host socket fd %d: %v", hostFile.FD(), err)
- return nil, syserr.FromError(err)
+ return nil, syserr.ErrConnectionRefused
}
- // After duplicating, we no longer need hostFile.
- hostFile.Close()
- c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path)
+ c, serr := host.NewSCMEndpoint(ctx, hostFile.FD(), queue, e.path)
if serr != nil {
- log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.dentry.file, flags, serr)
+ hostFile.Close()
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v sockType %d: %v", e.dentry.file, sockType, serr)
return nil, serr
}
+ // Ownership has been transferred to c.
+ hostFile.Release()
return c, nil
}
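
The p9 path above no longer dups the socket FD: host.NewSCMEndpoint takes ownership of the descriptor on success, so the caller closes hostFile only on failure and otherwise calls Release to relinquish ownership without closing. A sketch of this pkg/fd hand-off idiom (giveToConsumer is hypothetical):

	f := fd.New(rawFD)
	if err := giveToConsumer(f.FD()); err != nil {
		f.Close() // still our descriptor; close it
		return err
	}
	f.Release() // consumer owns rawFD now; Release forgets it without closing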
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index 4b59c1c3c..c568bbfd2 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -22,13 +22,16 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fsmetric"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -76,6 +79,16 @@ type specialFileFD struct {
bufMu sync.Mutex `state:"nosave"`
haveBuf uint32
buf []byte
+
+ // If handle.fd >= 0, hostFileMapper caches mappings of handle.fd, and
+ // hostFileMapperInitOnce is used to initialize it on first use.
+ hostFileMapperInitOnce sync.Once `state:"nosave"`
+ hostFileMapper fsutil.HostFileMapper
+
+ // If handle.fd >= 0, fileRefs counts references on memmap.File offsets.
+ // fileRefs is protected by fileRefsMu.
+ fileRefsMu sync.Mutex `state:"nosave"`
+ fileRefs fsutil.FrameRefSet
}
func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*specialFileFD, error) {
@@ -137,6 +150,9 @@ func (fd *specialFileFD) OnClose(ctx context.Context) error {
if !fd.vfsfd.IsWritable() {
return nil
}
+ if fs := fd.filesystem(); fs.opts.lisaEnabled {
+ return fd.handle.fdLisa.Flush(ctx)
+ }
return fd.handle.file.flush(ctx)
}
@@ -172,6 +188,9 @@ func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint
if fd.isRegularFile {
d := fd.dentry()
return d.doAllocate(ctx, offset, length, func() error {
+ if d.fs.opts.lisaEnabled {
+ return fd.handle.fdLisa.Allocate(ctx, mode, offset, length)
+ }
return fd.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length)
})
}
@@ -230,23 +249,13 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
}
}
- // Going through dst.CopyOutFrom() would hold MM locks around file
- // operations of unknown duration. For regularFileFD, doing so is necessary
- // to support mmap due to lock ordering; MM locks precede dentry.dataMu.
- // That doesn't hold here since specialFileFD doesn't client-cache data.
- // Just buffer the read instead.
- buf := make([]byte, dst.NumBytes())
- n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
+ rw := getHandleReadWriter(ctx, &fd.handle, offset)
+ n, err := dst.CopyOutFrom(ctx, rw)
+ putHandleReadWriter(rw)
if linuxerr.Equals(linuxerr.EAGAIN, err) {
- err = syserror.ErrWouldBlock
- }
- if n == 0 {
- return bufN, err
+ err = linuxerr.ErrWouldBlock
}
- if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil {
- return bufN + int64(cp), cperr
- }
- return bufN + int64(n), err
+ return bufN + n, err
}
// Read implements vfs.FileDescriptionImpl.Read.
@@ -317,20 +326,15 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
}
}
- // Do a buffered write. See rationale in PRead.
- buf := make([]byte, src.NumBytes())
- copied, copyErr := src.CopyIn(ctx, buf)
- if copied == 0 && copyErr != nil {
- // Only return the error if we didn't get any data.
- return 0, offset, copyErr
- }
- n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:copied])), uint64(offset))
+ rw := getHandleReadWriter(ctx, &fd.handle, offset)
+ n, err := src.CopyInTo(ctx, rw)
+ putHandleReadWriter(rw)
if linuxerr.Equals(linuxerr.EAGAIN, err) {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
// Update offset if the offset is valid.
if offset >= 0 {
- offset += int64(n)
+ offset += n
}
// Update file size for regular files.
if fd.isRegularFile {
@@ -341,10 +345,7 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
atomic.StoreUint64(&d.size, uint64(offset))
}
}
- if err != nil {
- return int64(n), offset, err
- }
- return int64(n), offset, copyErr
+ return int64(n), offset, err
}
// Write implements vfs.FileDescriptionImpl.Write.
@@ -377,10 +378,10 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) (
// Sync implements vfs.FileDescriptionImpl.Sync.
func (fd *specialFileFD) Sync(ctx context.Context) error {
- return fd.sync(ctx, false /* forFilesystemSync */)
+ return fd.sync(ctx, false /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */)
}
-func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error {
+func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool, accFsyncFDIDsLisa *[]lisafs.FDID) error {
// Locks to ensure it didn't race with fd.Release().
fd.releaseMu.RLock()
defer fd.releaseMu.RUnlock()
@@ -397,6 +398,13 @@ func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error
ctx.UninterruptibleSleepFinish(false)
return err
}
+ if fs := fd.filesystem(); fs.opts.lisaEnabled {
+ if accFsyncFDIDsLisa != nil {
+ *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, fd.handle.fdLisa.ID())
+ return nil
+ }
+ return fd.handle.fdLisa.Sync(ctx)
+ }
return fd.handle.file.fsync(ctx)
}()
if err != nil {
@@ -412,3 +420,85 @@ func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error
}
return nil
}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *specialFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ if fd.handle.fd < 0 || fd.filesystem().opts.forcePageCache {
+ return linuxerr.ENODEV
+ }
+ // After this point, fd may be used as a memmap.Mappable and memmap.File.
+ fd.hostFileMapperInitOnce.Do(fd.hostFileMapper.Init)
+ return vfs.GenericConfigureMMap(&fd.vfsfd, fd, opts)
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (fd *specialFileFD) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error {
+ fd.hostFileMapper.IncRefOn(memmap.MappableRange{offset, offset + uint64(ar.Length())})
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (fd *specialFileFD) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) {
+ fd.hostFileMapper.DecRefOn(memmap.MappableRange{offset, offset + uint64(ar.Length())})
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (fd *specialFileFD) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error {
+ return fd.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (fd *specialFileFD) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
+ mr := optional
+ if fd.filesystem().opts.limitHostFDTranslation {
+ mr = maxFillRange(required, optional)
+ }
+ return []memmap.Translation{
+ {
+ Source: mr,
+ File: fd,
+ Offset: mr.Start,
+ Perms: hostarch.AnyAccess,
+ },
+ }, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (fd *specialFileFD) InvalidateUnsavable(ctx context.Context) error {
+ return nil
+}
+
+// IncRef implements memmap.File.IncRef.
+func (fd *specialFileFD) IncRef(fr memmap.FileRange) {
+ fd.fileRefsMu.Lock()
+ defer fd.fileRefsMu.Unlock()
+ fd.fileRefs.IncRefAndAccount(fr)
+}
+
+// DecRef implements memmap.File.DecRef.
+func (fd *specialFileFD) DecRef(fr memmap.FileRange) {
+ fd.fileRefsMu.Lock()
+ defer fd.fileRefsMu.Unlock()
+ fd.fileRefs.DecRefAndAccount(fr)
+}
+
+// MapInternal implements memmap.File.MapInternal.
+func (fd *specialFileFD) MapInternal(fr memmap.FileRange, at hostarch.AccessType) (safemem.BlockSeq, error) {
+ fd.requireHostFD()
+ return fd.hostFileMapper.MapInternal(fr, int(fd.handle.fd), at.Write)
+}
+
+// FD implements memmap.File.FD.
+func (fd *specialFileFD) FD() int {
+ fd.requireHostFD()
+ return int(fd.handle.fd)
+}
+
+func (fd *specialFileFD) requireHostFD() {
+ if fd.handle.fd < 0 {
+ // This is possible if fd was successfully mmapped before saving, then
+ // was restored without a host FD. This is unrecoverable: without a
+ // host FD, we can't mmap this file post-restore.
+ panic("gofer.specialFileFD can no longer be memory-mapped without a host FD")
+ }
+}
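
With these methods specialFileFD doubles as a memmap.File backed directly by the host FD: Translate returns a single translation pointing at fd itself, and IncRef/DecRef keep per-range counts in fileRefs so the backing range stays pinned while user mappings exist. A hedged sketch of how a memory-manager-style caller consumes a translation (mapRange is hypothetical):

	ts, err := mappable.Translate(ctx, required, optional, hostarch.Read)
	if err != nil {
		return err
	}
	for _, t := range ts {
		t.File.IncRef(t.FileRange()) // pin the backing file range
		mapRange(t.Source, t.File.FD(), t.Offset, t.Perms)
	}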
diff --git a/pkg/sentry/fsimpl/gofer/symlink.go b/pkg/sentry/fsimpl/gofer/symlink.go
index dbd834c67..27d9be5c4 100644
--- a/pkg/sentry/fsimpl/gofer/symlink.go
+++ b/pkg/sentry/fsimpl/gofer/symlink.go
@@ -35,7 +35,13 @@ func (d *dentry) readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
return target, nil
}
}
- target, err := d.file.readlink(ctx)
+ var target string
+ var err error
+ if d.fs.opts.lisaEnabled {
+ target, err = d.controlFDLisa.ReadLinkAt(ctx)
+ } else {
+ target, err = d.file.readlink(ctx)
+ }
if d.fs.opts.interop != InteropModeShared {
if err == nil {
d.haveTarget = true
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
index 9cbe805b9..07940b225 100644
--- a/pkg/sentry/fsimpl/gofer/time.go
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -17,6 +17,7 @@ package gofer
import (
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
@@ -24,6 +25,10 @@ func dentryTimestampFromP9(s, ns uint64) int64 {
return int64(s*1e9 + ns)
}
+func dentryTimestampFromLisa(t linux.StatxTimestamp) int64 {
+ return t.Sec*1e9 + int64(t.Nsec)
+}
+
// Preconditions: d.cachedMetadataAuthoritative() == true.
func (d *dentry) touchAtime(mnt *vfs.Mount) {
if mnt.Flags.NoATime || mnt.ReadOnly() {
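
dentryTimestampFromLisa above flattens a statx timestamp to the nanosecond int64 the dentry stores, matching dentryTimestampFromP9. A worked example of the conversion:

	ts := linux.StatxTimestamp{Sec: 2, Nsec: 500000000}
	nanos := dentryTimestampFromLisa(ts) // 2*1e9 + 500000000 = 2500000000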
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index 476545d00..180a35583 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -70,7 +70,6 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/tcpip",
"//pkg/unet",
"//pkg/usermem",
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 89aa7b3d9..984c6e8ee 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -37,7 +37,6 @@ import (
unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -712,7 +711,7 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts
if total != 0 {
err = nil
} else {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
}
return total, err
@@ -766,7 +765,7 @@ func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opt
if !i.seekable {
n, err := f.writeToHostFD(ctx, src, -1, opts.Flags)
if isBlockError(err) {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
return n, err
}
diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go
index 7f6ce4ee5..04ac73255 100644
--- a/pkg/sentry/fsimpl/host/tty.go
+++ b/pkg/sentry/fsimpl/host/tty.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -346,7 +345,7 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal)
// If the signal is SIGTTIN, then we are attempting to read
// from the TTY. Don't send the signal and return EIO.
if sig == linux.SIGTTIN {
- return syserror.EIO
+ return linuxerr.EIO
}
// Otherwise, we are writing or changing terminal state. This is allowed.
@@ -355,7 +354,7 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal)
// If the process group is an orphan, return EIO.
if pg.IsOrphan() {
- return syserror.EIO
+ return linuxerr.EIO
}
// Otherwise, send the signal to the process group and return ERESTARTSYS.
@@ -368,5 +367,5 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal)
//
// Linux ignores the result of kill_pgrp().
_ = pg.SendSignal(kernel.SignalInfoPriv(sig))
- return syserror.ERESTARTSYS
+ return linuxerr.ERESTARTSYS
}
diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go
index 95d7ebe2e..9850f3f41 100644
--- a/pkg/sentry/fsimpl/host/util.go
+++ b/pkg/sentry/fsimpl/host/util.go
@@ -42,7 +42,7 @@ func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp {
}
// isBlockError checks if an error is EAGAIN or EWOULDBLOCK.
-// If so, they can be transformed into syserror.ErrWouldBlock.
+// If so, they can be transformed into linuxerr.ErrWouldBlock.
func isBlockError(err error) bool {
return linuxerr.Equals(linuxerr.EAGAIN, err) || linuxerr.Equals(linuxerr.EWOULDBLOCK, err)
}
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index d53937db6..4b577ea43 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -119,7 +119,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
],
)
@@ -137,6 +136,7 @@ go_test(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/errors/linuxerr",
+ "//pkg/fspath",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
@@ -144,7 +144,6 @@ go_test(
"//pkg/sentry/fsimpl/testutil",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
- "//pkg/syserror",
"//pkg/usermem",
"@com_github_google_go_cmp//cmp:go_default_library",
],
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 8b008dc10..7db1473c4 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -99,7 +98,7 @@ func NewGenericDirectoryFD(m *vfs.Mount, d *Dentry, children *OrderedChildren, l
func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions, fdOpts GenericDirectoryFDOptions) error {
if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 {
// Can't open directories for writing.
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
fd.LockFD.Init(locks)
fd.seekEnd = fdOpts.SeekEnd
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index a97473f7d..363ebc466 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// stepExistingLocked resolves rp.Component() in parent directory vfsd.
@@ -224,7 +223,7 @@ func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string
return linuxerr.EEXIST
}
if parent.VFSDentry().IsDead() {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite); err != nil {
return err
@@ -241,7 +240,7 @@ func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry) er
return linuxerr.EBUSY
}
if parent.vfsd.IsDead() {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
@@ -362,7 +361,7 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
return err
}
if rp.MustBeDir() {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if rp.Mount() != vd.Mount() {
return linuxerr.EXDEV
@@ -443,7 +442,7 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return err
}
if rp.MustBeDir() {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
return err
@@ -509,7 +508,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
defer unlock()
if rp.Done() {
if rp.MustBeDir() {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
if mustCreate {
return nil, linuxerr.EEXIST
@@ -536,11 +535,11 @@ afterTrailingSymlink:
}
// Reject attempts to open directories with O_CREAT.
if rp.MustBeDir() {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
pc := rp.Component()
if pc == "." || pc == ".." {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
if len(pc) > linux.NAME_MAX {
return nil, linuxerr.ENAMETOOLONG
@@ -861,7 +860,7 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
return err
}
if rp.MustBeDir() {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
return err
@@ -895,7 +894,7 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return err
}
if d.isDir() {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
virtfs := rp.VirtualFilesystem()
parentDentry := d.parent
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index a42fc79b4..b96dc9ef7 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -26,7 +26,6 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// InodeNoopRefCount partially implements the Inode interface, specifically the
@@ -234,6 +233,11 @@ func (a *InodeAttrs) Mode() linux.FileMode {
return linux.FileMode(atomic.LoadUint32(&a.mode))
}
+// Links returns the link count.
+func (a *InodeAttrs) Links() uint32 {
+ return atomic.LoadUint32(&a.nlink)
+}
+
// TouchAtime updates a.atime to the current time.
func (a *InodeAttrs) TouchAtime(ctx context.Context, mnt *vfs.Mount) {
if mnt.Flags.NoATime || mnt.ReadOnly() {
@@ -289,7 +293,7 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut
return linuxerr.EPERM
}
if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
return err
@@ -475,7 +479,7 @@ func (o *OrderedChildren) Lookup(ctx context.Context, name string) (Inode, error
s, ok := o.set[name]
if !ok {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
s.inode.IncRef() // This ref is passed to the dentry upon creation via Init.
@@ -502,6 +506,30 @@ func (o *OrderedChildren) Insert(name string, child Inode) error {
return o.insert(name, child, false)
}
+// Inserter is like Insert, but obtains the child to insert by calling
+// makeChild. makeChild is only called if the insert will succeed. This allows
+// the caller to atomically check and insert a child without having to
+// clean up the child on failure.
+func (o *OrderedChildren) Inserter(name string, makeChild func() Inode) (Inode, error) {
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ if _, ok := o.set[name]; ok {
+ return nil, linuxerr.EEXIST
+ }
+
+ // Note: We must not fail after we call makeChild().
+
+ child := makeChild()
+ s := &slot{
+ name: name,
+ inode: child,
+ static: false,
+ }
+ o.order.PushBack(s)
+ o.set[name] = s
+ return child, nil
+}
+
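
A typical Inserter call constructs the child inode inside makeChild, so allocation happens only once the name is known to be free and no cleanup path is needed on EEXIST. A hedged sketch from the perspective of an inode embedding OrderedChildren (newChildInode is hypothetical):

	child, err := dir.Inserter(name, func() kernfs.Inode {
		return newChildInode(ctx) // only runs if the insert will succeed
	})
	if err != nil {
		return nil, err // linuxerr.EEXIST; nothing was allocated
	}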
// insert inserts child into o.
//
// Precondition: Caller must be holding a ref on child if static is true.
@@ -559,7 +587,7 @@ func (o *OrderedChildren) replaceChildLocked(ctx context.Context, name string, n
func (o *OrderedChildren) checkExistingLocked(name string, child Inode) error {
s, ok := o.set[name]
if !ok {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if s.inode != child {
panic(fmt.Sprintf("Inode doesn't match what kernfs thinks! OrderedChild: %+v, kernfs: %+v", s.inode, child))
@@ -746,5 +774,5 @@ type InodeNoStatFS struct{}
// StatFS implements Inode.StatFS.
func (*InodeNoStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) {
- return linux.Statfs{}, syserror.ENOSYS
+ return linux.Statfs{}, linuxerr.ENOSYS
}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 0e2867d49..544698694 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -61,6 +61,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -542,6 +543,63 @@ func (d *Dentry) FSLocalPath() string {
return b.String()
}
+// WalkDentryTree traverses p in the dentry tree for this filesystem. Note that
+// this only traverses the dentry tree and is not a general path traversal.
+// Symlinks and dynamic children are not resolved, and no permission checks are
+// performed. The caller is responsible for ensuring the returned Dentry exists
+// for an appropriate lifetime.
+//
+// p is interpreted starting at d, and may be absolute or relative (absolute vs
+// relative paths both refer to the same target here, since p is absolute from
+// d). p may contain "." and "..", but will not allow traversal above d (similar
+// to ".." at the root dentry).
+//
+// This is useful for filesystem internals, where the filesystem may not be
+// mounted yet. For a mounted filesystem, use GetDentryAt.
+func (d *Dentry) WalkDentryTree(ctx context.Context, vfsObj *vfs.VirtualFilesystem, p fspath.Path) (*Dentry, error) {
+ d.fs.mu.RLock()
+ defer d.fs.processDeferredDecRefs(ctx)
+ defer d.fs.mu.RUnlock()
+
+ target := d
+
+ for pit := p.Begin; pit.Ok(); pit = pit.Next() {
+ pc := pit.String()
+
+ switch {
+ case target == nil:
+ return nil, linuxerr.ENOENT
+ case pc == ".":
+ // No-op, consume component and continue.
+ case pc == "..":
+ if target == d {
+ // Don't let .. traverse above the start point of the walk.
+ continue
+ }
+ target = target.parent
+ // Parent doesn't need revalidation since we revalidated it on the
+ // way to the child, and we're still holding fs.mu.
+ default:
+ var err error
+
+ parent := target
+ parent.dirMu.Lock()
+ target, err = d.fs.revalidateChildLocked(ctx, vfsObj, parent, pc, parent.children[pc])
+ parent.dirMu.Unlock()
+
+ if err != nil {
+ return nil, err
+ }
+ }
+ }
+
+ if target == nil {
+ return nil, linuxerr.ENOENT
+ }
+
+ target.IncRef()
+ return target, nil
+}
+
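
Filesystem internals can resolve a dentry-tree path with WalkDentryTree even before the filesystem is mounted; the test in kernfs_test.go below exercises this thoroughly. A minimal usage sketch:

	// Resolve "a/b" relative to root without VFS path resolution.
	d, err := root.WalkDentryTree(ctx, vfsObj, fspath.Parse("a/b"))
	if err != nil {
		return err
	}
	defer d.DecRef(ctx) // WalkDentryTree returns a held reference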
// The Inode interface maps filesystem-level operations that operate on paths to
// equivalent operations on specific filesystem nodes.
//
@@ -667,12 +725,15 @@ type inodeDirectory interface {
// RmDir removes an empty child directory from this directory
// inode. Implementations must update the parent directory's link count,
// if required. Implementations are not responsible for checking that child
- // is a directory, checking for an empty directory.
+ // is a directory, or checking for an empty directory.
RmDir(ctx context.Context, name string, child Inode) error
// Rename is called on the source directory containing an inode being
- // renamed. child should point to the resolved child in the source
- // directory.
+ // renamed. child points to the resolved child in the source directory.
+ // dstDir is guaranteed to be a directory inode.
+ //
+ // On a successful call to Rename, the caller updates the dentry tree to
+ // reflect the name change.
//
// Precondition: Caller must serialize concurrent calls to Rename.
Rename(ctx context.Context, oldname, newname string, child, dstDir Inode) error
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index 609887943..a2aba9321 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
@@ -346,3 +347,63 @@ func TestDirFDIterDirents(t *testing.T) {
"file1": linux.DT_REG,
})
}
+
+func TestDirWalkDentryTree(t *testing.T) {
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "dir1": fs.newDir(ctx, creds, 0755, nil),
+ "dir2": fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "file1": fs.newFile(ctx, creds, staticFileContent),
+ "dir3": fs.newDir(ctx, creds, 0755, nil),
+ }),
+ })
+ })
+ defer sys.Destroy()
+
+ testWalk := func(from *kernfs.Dentry, getDentryPath, walkPath string, expectedErr error) {
+ var d *kernfs.Dentry
+ if getDentryPath != "" {
+ pop := sys.PathOpAtRoot(getDentryPath)
+ vd := sys.GetDentryOrDie(pop)
+ defer vd.DecRef(sys.Ctx)
+ d = vd.Dentry().Impl().(*kernfs.Dentry)
+ }
+
+ match, err := from.WalkDentryTree(sys.Ctx, sys.VFS, fspath.Parse(walkPath))
+ if err == nil {
+ defer match.DecRef(sys.Ctx)
+ }
+
+ if err != expectedErr {
+ t.Fatalf("WalkDentryTree from %q to %q (with expected error: %v) unexpected error, want: %v, got: %v", from.FSLocalPath(), walkPath, expectedErr, expectedErr, err)
+ }
+ if expectedErr != nil {
+ return
+ }
+
+ if d != match {
+ t.Fatalf("WalkDentryTree from %q to %q (with expected error: %v) found unexpected dentry; want: %v, got: %v", from.FSLocalPath(), walkPath, expectedErr, d, match)
+ }
+ }
+
+ rootD := sys.Root.Dentry().Impl().(*kernfs.Dentry)
+
+ testWalk(rootD, "dir1", "/dir1", nil)
+ testWalk(rootD, "", "/dir-non-existent", linuxerr.ENOENT)
+ testWalk(rootD, "", "/dir1/child-non-existent", linuxerr.ENOENT)
+ testWalk(rootD, "", "/dir2/inner-non-existent/dir3", linuxerr.ENOENT)
+
+ testWalk(rootD, "dir2/dir3", "/dir2/../dir2/dir3", nil)
+ testWalk(rootD, "dir2/dir3", "/dir2/././dir3", nil)
+ testWalk(rootD, "dir2/dir3", "/dir2/././dir3/.././dir3", nil)
+
+ pop := sys.PathOpAtRoot("dir2")
+ dir2VD := sys.GetDentryOrDie(pop)
+ defer dir2VD.DecRef(sys.Ctx)
+ dir2D := dir2VD.Dentry().Impl().(*kernfs.Dentry)
+
+ testWalk(dir2D, "dir2/dir3", "/dir3", nil)
+ testWalk(dir2D, "dir2/dir3", "/../../../dir3", nil)
+ testWalk(dir2D, "dir2/file1", "/file1", nil)
+ testWalk(dir2D, "dir2/file1", "file1", nil)
+}
diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD
index ed730e215..d16dfef9b 100644
--- a/pkg/sentry/fsimpl/overlay/BUILD
+++ b/pkg/sentry/fsimpl/overlay/BUILD
@@ -42,7 +42,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 1f85a1f0d..520487066 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
func (d *dentry) isCopiedUp() bool {
@@ -37,6 +36,10 @@ func (d *dentry) isCopiedUp() bool {
//
// Preconditions: filesystem.renameMu must be locked.
func (d *dentry) copyUpLocked(ctx context.Context) error {
+ return d.copyUpMaybeSyntheticMountpointLocked(ctx, false /* forSyntheticMountpoint */)
+}
+
+func (d *dentry) copyUpMaybeSyntheticMountpointLocked(ctx context.Context, forSyntheticMountpoint bool) error {
// Fast path.
if d.isCopiedUp() {
return nil
@@ -60,7 +63,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
// d is a filesystem root with no upper layer.
return linuxerr.EROFS
}
- if err := d.parent.copyUpLocked(ctx); err != nil {
+ if err := d.parent.copyUpMaybeSyntheticMountpointLocked(ctx, forSyntheticMountpoint); err != nil {
return err
}
@@ -72,7 +75,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
}
if d.vfsd.IsDead() {
// Raced with deletion of d.
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
// Obtain settable timestamps from the lower layer.
@@ -169,7 +172,8 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
case linux.S_IFDIR:
if err := vfsObj.MkdirAt(ctx, d.fs.creds, &newpop, &vfs.MkdirOptions{
- Mode: linux.FileMode(d.mode &^ linux.S_IFMT),
+ Mode: linux.FileMode(d.mode &^ linux.S_IFMT),
+ ForSyntheticMountpoint: forSyntheticMountpoint,
}); err != nil {
return err
}
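
The forSyntheticMountpoint flag is threaded through the recursive copy-up so that each upper-layer directory created along the way carries vfs.MkdirOptions.ForSyntheticMountpoint, marking it as existing only to host a mount. A hedged sketch of a caller that triggers this path (the path itself is illustrative):

	err := vfsObj.MkdirAt(ctx, creds, &vfs.PathOperation{
		Root:  root,
		Start: root,
		Path:  fspath.Parse("run/secrets"),
	}, &vfs.MkdirOptions{
		Mode:                   0755,
		ForSyntheticMountpoint: true,
	})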
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 5e89928c5..3b3dcf836 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -28,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// _OVL_XATTR_PREFIX is an extended attribute key prefix to identify overlayfs
@@ -314,7 +313,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
}
if !topLookupLayer.existsInOverlay() {
child.destroyLocked(ctx)
- return nil, topLookupLayer, syserror.ENOENT
+ return nil, topLookupLayer, linuxerr.ENOENT
}
// Device and inode numbers were copied from the topmost layer above. Remap
@@ -463,13 +462,21 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath,
return d, nil
}
+type createType int
+
+const (
+ createNonDirectory createType = iota
+ createDirectory
+ createSyntheticMountpoint
+)
+
// doCreateAt checks that creating a file at rp is permitted, then invokes
// create to do so.
//
// Preconditions:
// * !rp.Done().
// * For the final path component in rp, !rp.ShouldFollowSymlink().
-func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error {
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, ct createType, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
@@ -483,7 +490,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
return linuxerr.EEXIST
}
if parent.vfsd.IsDead() {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
@@ -505,8 +512,8 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
return linuxerr.EEXIST
}
- if !dir && rp.MustBeDir() {
- return syserror.ENOENT
+ if ct == createNonDirectory && rp.MustBeDir() {
+ return linuxerr.ENOENT
}
mnt := rp.Mount()
@@ -519,7 +526,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
}
// Ensure that the parent directory is copied-up so that we can create the
// new file in the upper layer.
- if err := parent.copyUpLocked(ctx); err != nil {
+ if err := parent.copyUpMaybeSyntheticMountpointLocked(ctx, ct == createSyntheticMountpoint); err != nil {
return err
}
@@ -530,7 +537,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
parent.dirents = nil
ev := linux.IN_CREATE
- if dir {
+ if ct != createNonDirectory {
ev |= linux.IN_ISDIR
}
parent.watches.Notify(ctx, name, uint32(ev), 0 /* cookie */, vfs.InodeEvent, false /* unlinked */)
@@ -619,7 +626,7 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa
// LinkAt implements vfs.FilesystemImpl.LinkAt.
func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ return fs.doCreateAt(ctx, rp, createNonDirectory, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
if rp.Mount() != vd.Mount() {
return linuxerr.EXDEV
}
@@ -672,7 +679,11 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
- return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ ct := createDirectory
+ if opts.ForSyntheticMountpoint {
+ ct = createSyntheticMountpoint
+ }
+ return fs.doCreateAt(ctx, rp, ct, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
vfsObj := fs.vfsfs.VirtualFilesystem()
pop := vfs.PathOperation{
Root: parent.upperVD,
@@ -723,7 +734,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
// MknodAt implements vfs.FilesystemImpl.MknodAt.
func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ return fs.doCreateAt(ctx, rp, createNonDirectory, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
// Disallow attempts to create whiteouts.
if opts.Mode&linux.S_IFMT == linux.S_IFCHR && opts.DevMajor == 0 && opts.DevMinor == 0 {
return linuxerr.EPERM
@@ -780,7 +791,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
start := rp.Start().Impl().(*dentry)
if rp.Done() {
if mayCreate && rp.MustBeDir() {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
if mustCreate {
return nil, linuxerr.EEXIST
@@ -807,7 +818,7 @@ afterTrailingSymlink:
}
// Reject attempts to open directories with O_CREAT.
if mayCreate && rp.MustBeDir() {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
// Determine whether or not we need to create a file.
parent.dirMu.Lock()
@@ -865,11 +876,11 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts *
if ftype == linux.S_IFDIR {
// Can't open directories with O_CREAT.
if opts.Flags&linux.O_CREAT != 0 {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
// Can't open directories writably.
if ats.MayWrite() {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
if opts.Flags&linux.O_DIRECT != 0 {
return nil, linuxerr.EINVAL
@@ -919,7 +930,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
return nil, err
}
if parent.vfsd.IsDead() {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
mnt := rp.Mount()
if err := mnt.CheckBeginWrite(); err != nil {
@@ -1086,7 +1097,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
defer newParent.dirMu.Unlock()
}
if newParent.vfsd.IsDead() {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
var (
replaced *dentry
@@ -1105,7 +1116,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
replacedVFSD = &replaced.vfsd
if replaced.isDir() {
if !renamed.isDir() {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
if genericIsAncestorDentry(replaced, renamed) {
return linuxerr.ENOTEMPTY
@@ -1477,7 +1488,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ return fs.doCreateAt(ctx, rp, createNonDirectory, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
vfsObj := fs.vfsfs.VirtualFilesystem()
pop := vfs.PathOperation{
Root: parent.upperVD,
@@ -1533,7 +1544,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
defer rp.Mount().EndWrite()
name := rp.Component()
if name == "." || name == ".." {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
if rp.MustBeDir() {
return linuxerr.ENOTDIR
@@ -1557,7 +1568,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return err
}
if child.isDir() {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
if err := parent.mayDelete(rp.Credentials(), child); err != nil {
return err
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 1d3d2d95f..95cfbdc42 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -102,7 +102,6 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/tcpip/header",
"//pkg/tcpip/network/ipv4",
"//pkg/usermem",
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index d99f90b36..e04ae6660 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// subtasksInode represents the inode for /proc/[pid]/task/ directory.
@@ -71,15 +70,15 @@ func (fs *filesystem) newSubtasks(ctx context.Context, task *kernel.Task, pidns
func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
tid, err := strconv.ParseUint(name, 10, 32)
if err != nil {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
subTask := i.pidns.TaskWithID(kernel.ThreadID(tid))
if subTask == nil {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
if subTask.ThreadGroup() != i.task.ThreadGroup() {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
return i.fs.newTaskInode(ctx, subTask, i.pidns, false, i.cgroupControllers)
}
@@ -88,7 +87,7 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode,
func (i *subtasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
tasks := i.task.ThreadGroup().MemberIDs(i.pidns)
if len(tasks) == 0 {
- return offset, syserror.ENOENT
+ return offset, linuxerr.ENOENT
}
if relOffset >= int64(len(tasks)) {
return offset, nil
@@ -124,7 +123,7 @@ type subtasksFD struct {
func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
if fd.task.ExitState() >= kernel.TaskExitZombie {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
return fd.GenericDirectoryFD.IterDirents(ctx, cb)
}
@@ -132,7 +131,7 @@ func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallbac
// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
if fd.task.ExitState() >= kernel.TaskExitZombie {
- return 0, syserror.ENOENT
+ return 0, linuxerr.ENOENT
}
return fd.GenericDirectoryFD.Seek(ctx, offset, whence)
}
@@ -140,7 +139,7 @@ func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int
// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *subtasksFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
if fd.task.ExitState() >= kernel.TaskExitZombie {
- return linux.Statx{}, syserror.ENOENT
+ return linux.Statx{}, linuxerr.ENOENT
}
return fd.GenericDirectoryFD.Stat(ctx, opts)
}
@@ -148,7 +147,7 @@ func (fd *subtasksFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Sta
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *subtasksFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
if fd.task.ExitState() >= kernel.TaskExitZombie {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
return fd.GenericDirectoryFD.SetStat(ctx, opts)
}
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index dfc0a924e..5c6412fc0 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -22,11 +22,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
func getTaskFD(t *kernel.Task, fd int32) (*vfs.FileDescription, kernel.FDFlags) {
@@ -142,11 +142,11 @@ func (i *fdDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.Ite
func (i *fdDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
fdInt, err := strconv.ParseInt(name, 10, 32)
if err != nil {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
fd := int32(fdInt)
if !taskFDExists(ctx, i.fs, i.task, fd) {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
return i.fs.newFDSymlink(ctx, i.task, fd, i.fs.NextIno()), nil
}
@@ -218,7 +218,7 @@ func (fs *filesystem) newFDSymlink(ctx context.Context, task *kernel.Task, fd in
func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) {
file, _ := getTaskFD(s.task, s.fd)
if file == nil {
- return "", syserror.ENOENT
+ return "", linuxerr.ENOENT
}
defer s.fs.SafeDecRefFD(ctx, file)
root := vfs.RootFromContext(ctx)
@@ -231,7 +231,7 @@ func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error)
func (s *fdSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) {
file, _ := getTaskFD(s.task, s.fd)
if file == nil {
- return vfs.VirtualDentry{}, "", syserror.ENOENT
+ return vfs.VirtualDentry{}, "", linuxerr.ENOENT
}
defer s.fs.SafeDecRefFD(ctx, file)
vd := file.VirtualDentry()
@@ -278,11 +278,11 @@ func (fs *filesystem) newFDInfoDirInode(ctx context.Context, task *kernel.Task)
func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
fdInt, err := strconv.ParseInt(name, 10, 32)
if err != nil {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
fd := int32(fdInt)
if !taskFDExists(ctx, i.fs, i.task, fd) {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
data := &fdInfoData{
fs: i.fs,
@@ -330,7 +330,7 @@ var _ dynamicInode = (*fdInfoData)(nil)
func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
file, descriptorFlags := getTaskFD(d.task, d.fd)
if file == nil {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
defer d.fs.SafeDecRefFD(ctx, file)
// TODO(b/121266871): Include pos, locks, and other data. For now we only
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 0ce3ed797..d3f9cf489 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -33,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -41,7 +40,7 @@ import (
// Linux 3.18, the limit is five lines." - user_namespaces(7)
const maxIDMapLines = 5
-// mm gets the kernel task's MemoryManager. No additional reference is taken on
+// getMM gets the kernel task's MemoryManager. No additional reference is taken on
// mm here. This is safe because MemoryManager.destroy is required to leave the
// MemoryManager in a state where it's still usable as a DynamicBytesSource.
func getMM(task *kernel.Task) *mm.MemoryManager {
@@ -491,7 +490,7 @@ func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64
return int64(n), nil
}
if readErr != nil {
- return 0, syserror.EIO
+ return 0, linuxerr.EIO
}
return 0, nil
}
@@ -609,12 +608,10 @@ func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.task.StartTime().Sub(s.task.Kernel().Timekeeper().BootTime())))
var vss, rss uint64
- s.task.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- }
- })
+ if mm := getMM(s.task); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
fmt.Fprintf(buf, "%d %d ", vss, rss/hostarch.PageSize)
// rsslim.
@@ -650,13 +647,10 @@ var _ dynamicInode = (*statmData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
var vss, rss uint64
- s.task.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- }
- })
-
+ if mm := getMM(s.task); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/hostarch.PageSize, rss/hostarch.PageSize)
return nil
}
@@ -780,12 +774,12 @@ func (s *statusFD) Generate(ctx context.Context, buf *bytes.Buffer) error {
if fdTable := t.FDTable(); fdTable != nil {
fds = fdTable.CurrentMaxFDs()
}
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- data = mm.VirtualDataSize()
- }
})
+ if mm := getMM(s.task); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ data = mm.VirtualDataSize()
+ }
// Filesystem user/group IDs aren't implemented; effective UID/GID are used
// instead.
fmt.Fprintf(buf, "Uid:\t%d\t%d\t%d\t%d\n", ruid, euid, suid, euid)
@@ -946,25 +940,17 @@ func (s *exeSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDent
return vfs.VirtualDentry{}, "", err
}
- var err error
- var exec fsbridge.File
- s.task.WithMuLocked(func(t *kernel.Task) {
- mm := t.MemoryManager()
- if mm == nil {
- err = linuxerr.EACCES
- return
- }
+ mm := getMM(s.task)
+ if mm == nil {
+ return vfs.VirtualDentry{}, "", linuxerr.EACCES
+ }
- // The MemoryManager may be destroyed, in which case
- // MemoryManager.destroy will simply set the executable to nil
- // (with locks held).
- exec = mm.Executable()
- if exec == nil {
- err = linuxerr.ESRCH
- }
- })
- if err != nil {
- return vfs.VirtualDentry{}, "", err
+ // The MemoryManager may be destroyed, in which case
+ // MemoryManager.destroy will simply set the executable to nil
+ // (with locks held).
+ exec := mm.Executable()
+ if exec == nil {
+ return vfs.VirtualDentry{}, "", linuxerr.ESRCH
}
defer exec.DecRef(ctx)
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index cf905fae4..7b0be9c14 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -21,11 +21,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
const (
@@ -116,12 +116,12 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err
case threadSelfName:
return i.newThreadSelfSymlink(ctx, root), nil
}
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
task := i.pidns.TaskWithID(kernel.ThreadID(tid))
if task == nil {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
return i.fs.newTaskInode(ctx, task, i.pidns, true, i.fakeCgroupControllers)
@@ -268,6 +268,6 @@ func cpuInfoData(k *kernel.Kernel) string {
return buf.String()
}
-func shmData(v uint64) dynamicInode {
+func ipcData(v uint64) dynamicInode {
return newStaticFile(strconv.FormatUint(v, 10))
}
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index 03bed22a3..4d3a2f7e6 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// +stateify savable
@@ -58,7 +57,7 @@ func (s *selfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error
}
tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
if tgid == 0 {
- return "", syserror.ENOENT
+ return "", linuxerr.ENOENT
}
return strconv.FormatUint(uint64(tgid), 10), nil
}
@@ -100,7 +99,7 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string,
tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
tid := s.pidns.IDOfTask(t)
if tid == 0 || tgid == 0 {
- return "", syserror.ENOENT
+ return "", linuxerr.ENOENT
}
return fmt.Sprintf("%d/task/%d", tgid, tid), nil
}
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 99f64a9d8..82e2857b3 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -47,9 +47,12 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *
"kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"hostname": fs.newInode(ctx, root, 0444, &hostnameData{}),
"sem": fs.newInode(ctx, root, 0444, newStaticFile(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))),
- "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)),
- "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)),
- "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)),
+ "shmall": fs.newInode(ctx, root, 0444, ipcData(linux.SHMALL)),
+ "shmmax": fs.newInode(ctx, root, 0444, ipcData(linux.SHMMAX)),
+ "shmmni": fs.newInode(ctx, root, 0444, ipcData(linux.SHMMNI)),
+ "msgmni": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMNI)),
+ "msgmax": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMAX)),
+ "msgmnb": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMNB)),
"yama": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"ptrace_scope": fs.newYAMAPtraceScopeFile(ctx, k, root),
}),
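The three msg* files added above serve static snapshots of the ABI limits, one decimal value per file, through the renamed ipcData helper (strconv.FormatUint of the constant). As a rough illustration, a test could parse one back as below; the proc path and the expected value 16384 (Linux's stock MSGMNB default) are assumptions for the sketch, not part of this change.

package proc_test

import (
	"os"
	"strconv"
	"strings"
	"testing"
)

// TestMsgmnbDefault reads the new proc file and parses it the same way the
// sentry writes it: a single decimal uint64. 16384 is assumed here as the
// default, matching Linux's MSGMNB.
func TestMsgmnbDefault(t *testing.T) {
	b, err := os.ReadFile("/proc/sys/kernel/msgmnb")
	if err != nil {
		t.Skipf("proc not available: %v", err)
	}
	got, err := strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64)
	if err != nil {
		t.Fatalf("ParseUint(%q): %v", b, err)
	}
	if want := uint64(16384); got != want {
		t.Errorf("msgmnb = %d, want %d", got, want)
	}
}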
diff --git a/pkg/sentry/fsimpl/signalfd/BUILD b/pkg/sentry/fsimpl/signalfd/BUILD
index adb610213..403c6f254 100644
--- a/pkg/sentry/fsimpl/signalfd/BUILD
+++ b/pkg/sentry/fsimpl/signalfd/BUILD
@@ -9,10 +9,10 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/sentry/kernel",
"//pkg/sentry/vfs",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
index a7f5928b7..bdb03ef96 100644
--- a/pkg/sentry/fsimpl/signalfd/signalfd.go
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -18,10 +18,10 @@ package signalfd
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -91,7 +91,7 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen
info, err := sfd.target.Sigtimedwait(sfd.Mask(), 0)
if err != nil {
// There must be no signal available.
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
// Copy out the signal info using the specified format.
diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD
index 1af0a5cbc..ab21f028e 100644
--- a/pkg/sentry/fsimpl/sys/BUILD
+++ b/pkg/sentry/fsimpl/sys/BUILD
@@ -36,7 +36,6 @@ go_library(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
"//pkg/sentry/vfs",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index f322d2747..7fcb2d26b 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -84,6 +84,18 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fs.MaxCachedDentries = maxCachedDentries
fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+ k := kernel.KernelFromContext(ctx)
+ fsDirChildren := make(map[string]kernfs.Inode)
+ // Create an empty directory to serve as the mount point for cgroupfs when
+ // cgroups are available. This emulates Linux behaviour, see
+ // kernel/cgroup.c:cgroup_init(). Note that in Linux, userspace (typically
+ // the init process) is ultimately responsible for actually mounting
+ // cgroupfs, but the kernel creates the mountpoint. For the sentry, the
+ // launcher mounts cgroupfs.
+ if k.CgroupRegistry() != nil {
+ fsDirChildren["cgroup"] = fs.newDir(ctx, creds, defaultSysDirMode, nil)
+ }
+
root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
"block": fs.newDir(ctx, creds, defaultSysDirMode, nil),
"bus": fs.newDir(ctx, creds, defaultSysDirMode, nil),
@@ -97,7 +109,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}),
}),
"firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil),
- "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "fs": fs.newDir(ctx, creds, defaultSysDirMode, fsDirChildren),
"kernel": kernelDir(ctx, fs, creds),
"module": fs.newDir(ctx, creds, defaultSysDirMode, nil),
"power": fs.newDir(ctx, creds, defaultSysDirMode, nil),
diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go
index 0a0d914cc..0c46a3a13 100644
--- a/pkg/sentry/fsimpl/sys/sys_test.go
+++ b/pkg/sentry/fsimpl/sys/sys_test.go
@@ -87,3 +87,17 @@ func TestSysRootContainsExpectedEntries(t *testing.T) {
"power": linux.DT_DIR,
})
}
+
+func TestCgroupMountpointExists(t *testing.T) {
+ // Note: The mountpoint is only created if cgroups are available. This is
+ // the VFS2 implementation of sysfs and the test runs with VFS2 enabled, so
+ // we expect to see the mount point unconditionally.
+ s := newTestSystem(t)
+ defer s.Destroy()
+ pop := s.PathOpAtRoot("/fs")
+ s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{
+ "cgroup": linux.DT_DIR,
+ })
+ pop = s.PathOpAtRoot("/fs/cgroup")
+ s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{ /*empty*/ })
+}
diff --git a/pkg/sentry/fsimpl/timerfd/BUILD b/pkg/sentry/fsimpl/timerfd/BUILD
index e6980a314..2b83d7d9a 100644
--- a/pkg/sentry/fsimpl/timerfd/BUILD
+++ b/pkg/sentry/fsimpl/timerfd/BUILD
@@ -12,7 +12,6 @@ go_library(
"//pkg/hostarch",
"//pkg/sentry/kernel/time",
"//pkg/sentry/vfs",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go
index 655a1c76a..68b785791 100644
--- a/pkg/sentry/fsimpl/timerfd/timerfd.go
+++ b/pkg/sentry/fsimpl/timerfd/timerfd.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -82,7 +81,7 @@ func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequenc
}
return sizeofUint64, nil
}
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
// Clock returns the timer fd's Clock.
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index dc8b9bfeb..94486bb63 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -82,7 +82,6 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/sentry/vfs/memxattr",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
],
)
@@ -125,7 +124,6 @@ go_test(
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 8b04df038..e067f136e 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Sync implements vfs.FilesystemImpl.Sync.
@@ -75,7 +74,7 @@ afterSymlink:
}
child, ok := dir.childMap[name]
if !ok {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
return nil, err
@@ -171,12 +170,12 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
return linuxerr.EEXIST
}
if !dir && rp.MustBeDir() {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
// tmpfs never calls VFS.InvalidateDentry(), so parentDir.dentry can only
// be dead if it was deleted.
if parentDir.dentry.vfsd.IsDead() {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
mnt := rp.Mount()
if err := mnt.CheckBeginWrite(); err != nil {
@@ -258,7 +257,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
return err
}
if i.nlink == 0 {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if i.nlink == maxLinks {
return linuxerr.EMLINK
@@ -345,7 +344,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if rp.Done() {
// Reject attempts to open mount root directory with O_CREAT.
if rp.MustBeDir() {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
if mustCreate {
return nil, linuxerr.EEXIST
@@ -366,11 +365,11 @@ afterTrailingSymlink:
}
// Reject attempts to open directories with O_CREAT.
if rp.MustBeDir() {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
name := rp.Component()
if name == "." || name == ".." {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
if len(name) > linux.NAME_MAX {
return nil, linuxerr.ENAMETOOLONG
@@ -457,7 +456,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
case *directory:
// Can't open directories writably.
if ats&vfs.MayWrite != 0 {
- return nil, syserror.EISDIR
+ return nil, linuxerr.EISDIR
}
var fd directoryFD
fd.LockFD.Init(&d.inode.locks)
@@ -532,7 +531,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
renamed, ok := oldParentDir.childMap[oldName]
if !ok {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if err := oldParentDir.mayDelete(rp.Credentials(), renamed); err != nil {
return err
@@ -567,7 +566,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
replacedDir, ok := replaced.inode.impl.(*directory)
if ok {
if !renamed.inode.isDir() {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
if len(replacedDir.childMap) != 0 {
return linuxerr.ENOTEMPTY
@@ -588,7 +587,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// tmpfs never calls VFS.InvalidateDentry(), so newParentDir.dentry can
// only be dead if it was deleted.
if newParentDir.dentry.vfsd.IsDead() {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
// Linux places this check before some of those above; we do it here for
@@ -654,7 +653,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
child, ok := parentDir.childMap[name]
if !ok {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if err := parentDir.mayDelete(rp.Credentials(), child); err != nil {
return err
@@ -754,17 +753,17 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
name := rp.Component()
if name == "." || name == ".." {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
child, ok := parentDir.childMap[name]
if !ok {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if err := parentDir.mayDelete(rp.Credentials(), child); err != nil {
return err
}
if child.inode.isDir() {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
if rp.MustBeDir() {
return linuxerr.ENOTDIR
diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
index 418c7994e..99afd9817 100644
--- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -202,7 +201,7 @@ func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
readData := make([]byte, 1)
dst := usermem.BytesIOSequence(readData)
bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{})
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err)
}
if bytesRead != 0 {
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
index 4393cc13b..cb7711b39 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
@@ -21,10 +21,10 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -146,7 +146,7 @@ func TestLocks(t *testing.T) {
if err := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, nil); err != nil {
t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
}
- if got, want := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.WriteLock, nil), syserror.ErrWouldBlock; got != want {
+ if got, want := fd.Impl().LockBSD(ctx, uid2, 0 /* ownerPID */, lock.WriteLock, nil), linuxerr.ErrWouldBlock; got != want {
t.Fatalf("fd.Impl().LockBSD failed: got = %v, want = %v", got, want)
}
if err := fd.Impl().UnlockBSD(ctx, uid1); err != nil {
@@ -165,7 +165,7 @@ func TestLocks(t *testing.T) {
if err := fd.Impl().LockPOSIX(ctx, uid1, 0 /* ownerPID */, lock.WriteLock, lock.LockRange{Start: 0, End: 1}, nil); err != nil {
t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
}
- if got, want := fd.Impl().LockPOSIX(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 0, End: 1}, nil), syserror.ErrWouldBlock; got != want {
+ if got, want := fd.Impl().LockPOSIX(ctx, uid2, 0 /* ownerPID */, lock.ReadLock, lock.LockRange{Start: 0, End: 1}, nil), linuxerr.ErrWouldBlock; got != want {
t.Fatalf("fd.Impl().LockPOSIX failed: got = %v, want = %v", got, want)
}
if err := fd.Impl().UnlockPOSIX(ctx, uid1, lock.LockRange{Start: 0, End: 1}); err != nil {
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index f2250c025..feafb06e4 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -44,7 +44,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sentry/vfs/memxattr"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Name is the default filesystem name.
@@ -556,7 +555,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.
needsCtimeBump = true
}
case *directory:
- return syserror.EISDIR
+ return linuxerr.EISDIR
default:
return linuxerr.EINVAL
}
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index 1d855234c..c12abdf33 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -1,10 +1,24 @@
load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
licenses(["notice"])
+go_template_instance(
+ name = "dentry_list",
+ out = "dentry_list.go",
+ package = "verity",
+ prefix = "dentry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*dentry",
+ "Linker": "*dentry",
+ },
+)
+
go_library(
name = "verity",
srcs = [
+ "dentry_list.go",
"filesystem.go",
"save_restore.go",
"verity.go",
@@ -28,7 +42,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 930016a3e..52d47994d 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -32,7 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -67,40 +66,23 @@ func putDentrySlice(ds *[]*dentry) {
dentrySlicePool.Put(ds)
}
-// renameMuRUnlockAndCheckDrop calls fs.renameMu.RUnlock(), then calls
-// dentry.checkDropLocked on all dentries in *ds with fs.renameMu locked for
+// renameMuRUnlockAndCheckCaching calls fs.renameMu.RUnlock(), then calls
+// dentry.checkCachingLocked on all dentries in *ds with fs.renameMu locked for
// writing.
//
// ds is a pointer-to-pointer since defer evaluates its arguments immediately,
// but dentry slices are allocated lazily, and it's much easier to say "defer
-// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() {
-// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this.
+// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() {
+// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this.
// +checklocksrelease:fs.renameMu
-func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
+func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) {
fs.renameMu.RUnlock()
if *ds == nil {
return
}
- if len(**ds) != 0 {
- fs.renameMu.Lock()
- for _, d := range **ds {
- d.checkDropLocked(ctx)
- }
- fs.renameMu.Unlock()
- }
- putDentrySlice(*ds)
-}
-
-// +checklocksrelease:fs.renameMu
-func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
- if *ds == nil {
- fs.renameMu.Unlock()
- return
- }
for _, d := range **ds {
- d.checkDropLocked(ctx)
+ d.checkCachingLocked(ctx, false /* renameMuWriteLocked */)
}
- fs.renameMu.Unlock()
putDentrySlice(*ds)
}
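The pointer-to-pointer shape called out in the comment above is the usual workaround for defer evaluating its arguments immediately; a minimal standalone sketch of the difference:

package main

import "fmt"

func main() {
	var ds *[]int // allocated lazily, like the dentry slice above

	// Evaluated now: the deferred call captures nil regardless of the
	// assignment below.
	defer fmt.Println("by value, ds == nil:", ds == nil) // prints true

	// Evaluated at call time: passing &ds lets the deferred function see
	// the final value.
	defer func(pds **[]int) {
		fmt.Println("by pointer, *pds == nil:", *pds == nil) // prints false
	}(&ds)

	s := []int{1, 2, 3}
	ds = &s
}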
@@ -166,7 +148,7 @@ afterSymlink:
// verifyChildLocked verifies the hash of child against the already verified
// hash of the parent to ensure the child is expected. verifyChild triggers a
// sentry panic if unexpected modifications to the file system are detected. In
-// ErrorOnViolation mode it returns a syserror instead.
+// ErrorOnViolation mode it returns a linuxerr instead.
//
// Preconditions:
// * fs.renameMu must be locked.
@@ -547,7 +529,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
if parent.verityEnabled() {
if _, ok := parent.childrenNames[name]; !ok {
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
}
@@ -595,23 +577,6 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
}
}
- // Clear the Merkle tree file if they are to be generated at runtime.
- // TODO(b/182315468): Optimize the Merkle tree generate process to
- // allow only updating certain files/directories.
- if fs.allowRuntimeEnable {
- childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
- Root: childMerkleVD,
- Start: childMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR | linux.O_TRUNC,
- Mode: 0644,
- })
- if err != nil {
- return nil, err
- }
- childMerkleFD.DecRef(ctx)
- }
-
// The dentry needs to be cleaned up if any error occurs. IncRef will be
// called if a verity child dentry is successfully created.
defer childMerkleVD.DecRef(ctx)
@@ -718,7 +683,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds
}
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return err
@@ -730,7 +695,7 @@ func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds
func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return nil, err
@@ -751,7 +716,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
d, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
if err != nil {
@@ -788,7 +753,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
start := rp.Start().Impl().(*dentry)
if rp.Done() {
@@ -970,7 +935,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return "", err
@@ -1000,7 +965,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return linux.Statx{}, err
@@ -1046,7 +1011,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
if _, err := fs.resolveLocked(ctx, rp, &ds); err != nil {
return nil, err
}
@@ -1057,7 +1022,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return nil, err
@@ -1073,7 +1038,7 @@ func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) {
var ds *[]*dentry
fs.renameMu.RLock()
- defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
d, err := fs.resolveLocked(ctx, rp, &ds)
if err != nil {
return "", err
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index c5fa9855b..d2526263c 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -23,10 +23,12 @@
// Lock order:
//
// filesystem.renameMu
-// dentry.dirMu
-// fileDescription.mu
-// filesystem.verityMu
-// dentry.hashMu
+// dentry.cachingMu
+// filesystem.cacheMu
+// dentry.dirMu
+// fileDescription.mu
+// filesystem.verityMu
+// dentry.hashMu
//
// Locking dentry.dirMu in multiple dentries requires that parent dentries are
// locked before child dentries, and that filesystem.renameMu is locked to
@@ -60,7 +62,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -97,6 +98,9 @@ const (
// sizeOfStringInt32 is the size for a 32 bit integer stored as string in
// extended attributes. The maximum value of a 32 bit integer has 10 digits.
sizeOfStringInt32 = 10
+
+ // defaultMaxCachedDentries is the default limit of the dentry cache.
+ defaultMaxCachedDentries = uint64(1000)
)
var (
@@ -107,9 +111,10 @@ var (
// Mount option names for verityfs.
const (
- moptLowerPath = "lower_path"
- moptRootHash = "root_hash"
- moptRootName = "root_name"
+ moptLowerPath = "lower_path"
+ moptRootHash = "root_hash"
+ moptRootName = "root_name"
+ moptDentryCacheLimit = "dentry_cache_limit"
)
// HashAlgorithm is a type specifying the algorithm used to hash the file
@@ -189,6 +194,17 @@ type filesystem struct {
// dentries.
renameMu sync.RWMutex `state:"nosave"`
+ // cachedDentries contains all dentries with 0 references. (Due to race
+ // conditions, it may also contain dentries with non-zero references.)
+ // cachedDentriesLen is the number of dentries in cachedDentries. These
+ // fields are protected by cacheMu.
+ cacheMu sync.Mutex `state:"nosave"`
+ cachedDentries dentryList
+ cachedDentriesLen uint64
+
+ // maxCachedDentries is the maximum size of filesystem.cachedDentries.
+ maxCachedDentries uint64
+
// verityMu synchronizes enabling verity files, protects files or
// directories from being enabled by different threads simultaneously.
// It also ensures that verity does not access files that are being
@@ -199,6 +215,10 @@ type filesystem struct {
// is for the whole file system to ensure that no more than one file is
// enabled the same time.
verityMu sync.RWMutex `state:"nosave"`
+
+ // released is nonzero once filesystem.Release has been called. It is accessed
+ // with atomic memory operations.
+ released int32
}
// InternalFilesystemOptions may be passed as
@@ -239,7 +259,7 @@ func (FilesystemType) Release(ctx context.Context) {}
// mode, it returns EIO; otherwise it panics.
func (fs *filesystem) alertIntegrityViolation(msg string) error {
if fs.action == ErrorOnViolation {
- return syserror.EIO
+ return linuxerr.EIO
}
panic(msg)
}
@@ -267,6 +287,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
delete(mopts, moptRootName)
rootName = root
}
+ maxCachedDentries := defaultMaxCachedDentries
+ if str, ok := mopts[moptDentryCacheLimit]; ok {
+ delete(mopts, moptDentryCacheLimit)
+ maxCD, err := strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("verity.FilesystemType.GetFilesystem: invalid dentry cache limit: %s=%s", moptDentryCacheLimit, str)
+ return nil, nil, linuxerr.EINVAL
+ }
+ maxCachedDentries = maxCD
+ }
// Check for unparsed options.
if len(mopts) != 0 {
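The new dentry_cache_limit option is handled with the same consume-and-validate pattern as the existing options: delete the key as it is parsed and let the final len(mopts) != 0 check reject leftovers. A condensed, generic sketch of that pattern (illustrative names, not the verity code itself):

package main

import (
	"fmt"
	"strconv"
)

// parseOpts mirrors the consume-and-validate flow above: known keys are
// deleted as they are parsed, and any leftovers fail the mount.
func parseOpts(mopts map[string]string) (uint64, error) {
	maxCached := uint64(1000) // defaultMaxCachedDentries
	if str, ok := mopts["dentry_cache_limit"]; ok {
		delete(mopts, "dentry_cache_limit")
		v, err := strconv.ParseUint(str, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("invalid dentry_cache_limit %q", str)
		}
		maxCached = v
	}
	if len(mopts) != 0 {
		return 0, fmt.Errorf("unknown mount options: %v", mopts)
	}
	return maxCached, nil
}

func main() {
	fmt.Println(parseOpts(map[string]string{"dentry_cache_limit": "500"})) // 500 <nil>
	fmt.Println(parseOpts(map[string]string{"bogus": "1"}))                // 0, unknown mount options
}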
@@ -340,12 +370,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
action: iopts.Action,
opts: opts.Data,
allowRuntimeEnable: iopts.AllowRuntimeEnable,
+ maxCachedDentries: maxCachedDentries,
}
fs.vfsfs.Init(vfsObj, &fstype, fs)
// Construct the root dentry.
d := fs.newDentry()
- d.refs = 1
+ // Set the root's reference count to 2. One reference is returned to
+ // the caller, and the other is held by fs to prevent the root from
+ // being "cached" and subsequently evicted.
+ d.refs = 2
lowerVD := vfs.MakeVirtualDentry(lowerMount, lowerMount.Root())
lowerVD.IncRef()
d.lowerVD = lowerVD
@@ -520,7 +554,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Release implements vfs.FilesystemImpl.Release.
func (fs *filesystem) Release(ctx context.Context) {
+ atomic.StoreInt32(&fs.released, 1)
fs.lowerMount.DecRef(ctx)
+
+ fs.renameMu.Lock()
+ fs.evictAllCachedDentriesLocked(ctx)
+ fs.renameMu.Unlock()
+
+ // An extra reference was held by the filesystem on the root to prevent
+ // it from being cached/evicted.
+ fs.rootDentry.DecRef(ctx)
}
// MountOptions implements vfs.FilesystemImpl.MountOptions.
@@ -534,6 +577,11 @@ func (fs *filesystem) MountOptions() string {
type dentry struct {
vfsd vfs.Dentry
+ // refs is the reference count. Each dentry holds a reference on its
+ // parent, even if disowned. When refs reaches 0, the dentry may be
+ // added to the cache or destroyed. If refs == -1, the dentry has
+ // already been destroyed. refs is accessed using atomic memory
+ // operations.
refs int64
// fs is the owning filesystem. fs is immutable.
@@ -588,13 +636,23 @@ type dentry struct {
// is protected by hashMu.
hashMu sync.RWMutex `state:"nosave"`
hash []byte
+
+ // cachingMu is used to synchronize concurrent dentry caching attempts on
+ // this dentry.
+ cachingMu sync.Mutex `state:"nosave"`
+
+ // If cached is true, dentryEntry links dentry into
+ // filesystem.cachedDentries. cached and dentryEntry are protected by
+ // cachingMu.
+ cached bool
+ dentryEntry
}
// newDentry creates a new dentry representing the given verity file. The
-// dentry initially has no references; it is the caller's responsibility to set
-// the dentry's reference count and/or call dentry.destroy() as appropriate.
-// The dentry is initially invalid in that it contains no underlying dentry;
-// the caller is responsible for setting them.
+// dentry initially has no references, but is not cached; it is the caller's
+// responsibility to set the dentry's reference count and/or call
+// dentry.destroy() as appropriate. The dentry is initially invalid in that it
+// contains no underlying dentry; the caller is responsible for setting them.
func (fs *filesystem) newDentry() *dentry {
d := &dentry{
fs: fs,
@@ -630,42 +688,23 @@ func (d *dentry) TryIncRef() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *dentry) DecRef(ctx context.Context) {
- r := atomic.AddInt64(&d.refs, -1)
- if d.LogRefs() {
- refsvfs2.LogDecRef(d, r)
- }
- if r == 0 {
- d.fs.renameMu.Lock()
- d.checkDropLocked(ctx)
- d.fs.renameMu.Unlock()
- } else if r < 0 {
- panic("verity.dentry.DecRef() called without holding a reference")
+ if d.decRefNoCaching() == 0 {
+ d.checkCachingLocked(ctx, false /* renameMuWriteLocked */)
}
}
-func (d *dentry) decRefLocked(ctx context.Context) {
+// decRefNoCaching decrements d's reference count without calling
+// d.checkCachingLocked, even if d's reference count reaches 0; callers are
+// responsible for ensuring that d.checkCachingLocked will be called later.
+func (d *dentry) decRefNoCaching() int64 {
r := atomic.AddInt64(&d.refs, -1)
if d.LogRefs() {
refsvfs2.LogDecRef(d, r)
}
- if r == 0 {
- d.checkDropLocked(ctx)
- } else if r < 0 {
- panic("verity.dentry.decRefLocked() called without holding a reference")
+ if r < 0 {
+ panic("verity.dentry.decRefNoCaching() called without holding a reference")
}
-}
-
-// checkDropLocked should be called after d's reference count becomes 0 or it
-// becomes deleted.
-func (d *dentry) checkDropLocked(ctx context.Context) {
- // Dentries with a positive reference count must be retained. Dentries
- // with a negative reference count have already been destroyed.
- if atomic.LoadInt64(&d.refs) != 0 {
- return
- }
- // Refs is still zero; destroy it.
- d.destroyLocked(ctx)
- return
+ return r
}
// destroyLocked destroys the dentry.
@@ -684,6 +723,12 @@ func (d *dentry) destroyLocked(ctx context.Context) {
panic("verity.dentry.destroyLocked() called with references on the dentry")
}
+ // Drop the reference held by d on its parent without recursively
+ // locking d.fs.renameMu.
+ if d.parent != nil && d.parent.decRefNoCaching() == 0 {
+ d.parent.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
+ }
+
if d.lowerVD.Ok() {
d.lowerVD.DecRef(ctx)
}
@@ -696,7 +741,6 @@ func (d *dentry) destroyLocked(ctx context.Context) {
delete(d.parent.children, d.name)
}
d.parent.dirMu.Unlock()
- d.parent.decRefLocked(ctx)
}
refsvfs2.Unregister(d)
}
@@ -735,6 +779,140 @@ func (d *dentry) OnZeroWatches(context.Context) {
//TODO(b/159261227): Implement OnZeroWatches.
}
+// checkCachingLocked should be called after d's reference count becomes 0 or
+// it becomes disowned.
+//
+// For performance, checkCachingLocked can also be called after d's reference
+// count becomes non-zero, so that d can be removed from the LRU cache. This
+// may help in reducing the size of the cache and hence reduce evictions. Note
+// that this is not necessary for correctness.
+//
+// It may be called on a destroyed dentry. For example,
+// renameMu[R]UnlockAndCheckCaching may call checkCachingLocked multiple times
+// for the same dentry when the dentry is visited more than once in the same
+// operation. One of the calls may destroy the dentry, so subsequent calls will
+// do nothing.
+//
+// Preconditions: d.fs.renameMu must be locked for writing if
+// renameMuWriteLocked is true; it may be temporarily unlocked.
+func (d *dentry) checkCachingLocked(ctx context.Context, renameMuWriteLocked bool) {
+ d.cachingMu.Lock()
+ refs := atomic.LoadInt64(&d.refs)
+ if refs == -1 {
+ // Dentry has already been destroyed.
+ d.cachingMu.Unlock()
+ return
+ }
+ if refs > 0 {
+ // fs.cachedDentries is permitted to contain dentries with non-zero refs,
+ // which are skipped by fs.evictCachedDentryLocked() upon reaching the end
+ // of the LRU. But it is still beneficial to remove d from the cache as we
+ // are already holding d.cachingMu. Keeping a cleaner cache also reduces
+ // the number of evictions (which is expensive as it acquires fs.renameMu).
+ d.removeFromCacheLocked()
+ d.cachingMu.Unlock()
+ return
+ }
+
+ if atomic.LoadInt32(&d.fs.released) != 0 {
+ d.cachingMu.Unlock()
+ if !renameMuWriteLocked {
+ // Need to lock d.fs.renameMu to access d.parent. Lock it for writing as
+ // needed by d.destroyLocked() later.
+ d.fs.renameMu.Lock()
+ defer d.fs.renameMu.Unlock()
+ }
+ if d.parent != nil {
+ d.parent.dirMu.Lock()
+ delete(d.parent.children, d.name)
+ d.parent.dirMu.Unlock()
+ }
+ d.destroyLocked(ctx) // +checklocksforce: see above.
+ return
+ }
+
+ d.fs.cacheMu.Lock()
+ // If d is already cached, just move it to the front of the LRU.
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentries.PushFront(d)
+ d.fs.cacheMu.Unlock()
+ d.cachingMu.Unlock()
+ return
+ }
+ // Cache the dentry, then evict the least recently used cached dentry if
+ // the cache becomes over-full.
+ d.fs.cachedDentries.PushFront(d)
+ d.fs.cachedDentriesLen++
+ d.cached = true
+ shouldEvict := d.fs.cachedDentriesLen > d.fs.maxCachedDentries
+ d.fs.cacheMu.Unlock()
+ d.cachingMu.Unlock()
+
+ if shouldEvict {
+ if !renameMuWriteLocked {
+ // Need to lock d.fs.renameMu for writing as needed by
+ // d.evictCachedDentryLocked().
+ d.fs.renameMu.Lock()
+ defer d.fs.renameMu.Unlock()
+ }
+ d.fs.evictCachedDentryLocked(ctx) // +checklocksforce: see above.
+ }
+}
+
+// Preconditions: d.cachingMu must be locked.
+func (d *dentry) removeFromCacheLocked() {
+ if d.cached {
+ d.fs.cacheMu.Lock()
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.fs.cacheMu.Unlock()
+ d.cached = false
+ }
+}
+
+// Precondition: fs.renameMu must be locked for writing; it may be temporarily
+// unlocked.
+// +checklocks:fs.renameMu
+func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) {
+ for fs.cachedDentriesLen != 0 {
+ fs.evictCachedDentryLocked(ctx)
+ }
+}
+
+// Preconditions:
+// * fs.renameMu must be locked for writing; it may be temporarily unlocked.
+// +checklocks:fs.renameMu
+func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) {
+ fs.cacheMu.Lock()
+ victim := fs.cachedDentries.Back()
+ fs.cacheMu.Unlock()
+ if victim == nil {
+ // fs.cachedDentries may have become empty between when it was
+ // checked and when we locked fs.cacheMu.
+ return
+ }
+
+ victim.cachingMu.Lock()
+ victim.removeFromCacheLocked()
+ // victim.refs may have become non-zero from an earlier path resolution
+ // since it was inserted into fs.cachedDentries.
+ if atomic.LoadInt64(&victim.refs) != 0 {
+ victim.cachingMu.Unlock()
+ return
+ }
+ if victim.parent != nil {
+ victim.parent.dirMu.Lock()
+ // Note that victim can't be a mount point (in any mount
+ // namespace), since VFS holds references on mount points.
+ fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
+ delete(victim.parent.children, victim.name)
+ victim.parent.dirMu.Unlock()
+ }
+ victim.cachingMu.Unlock()
+ victim.destroyLocked(ctx) // +checklocksforce: owned as precondition, victim.fs == fs.
+}
+
func (d *dentry) isSymlink() bool {
return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFLNK
}
@@ -1091,6 +1269,21 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) {
return 0, fd.d.fs.alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds")
}
+ // Populate children names here. We cannot rely on the children
+ // dentries to populate parent dentry's children names, because the
+ // parent dentry may be destroyed before users enable verity if its ref
+ // count drops to zero.
+ if fd.d.isDir() {
+ if err := fd.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error {
+ if dirent.Name != "." && dirent.Name != ".." {
+ fd.d.childrenNames[dirent.Name] = struct{}{}
+ }
+ return nil
+ })); err != nil {
+ return 0, err
+ }
+ }
+
hash, dataSize, err := fd.generateMerkleLocked(ctx)
if err != nil {
return 0, err
@@ -1118,9 +1311,6 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) {
}); err != nil {
return 0, err
}
-
- // Add the current child's name to parent's childrenNames.
- fd.d.parent.childrenNames[fd.d.name] = struct{}{}
}
// Record the size of the data being hashed for fd.
@@ -1215,7 +1405,7 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
case linux.FS_IOC_GETFLAGS:
return fd.verityFlags(ctx, args[2].Pointer())
default:
- return 0, syserror.ENOSYS
+ return 0, linuxerr.ENOSYS
}
}
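Taken together, checkCachingLocked and evictCachedDentryLocked implement a refcount-gated LRU: a dentry enters the list when its refcount reaches zero, moves to the front on reuse, and the back of the list is destroyed once it exceeds maxCachedDentries. A self-contained sketch of that shape using container/list; the real code instead uses the generated intrusive dentryList, a per-dentry cachingMu, and re-checks refs before destroying.

package main

import (
	"container/list"
	"fmt"
)

type entry struct {
	name string
	elem *list.Element // non-nil while the entry sits in the cache
}

type lru struct {
	l   *list.List
	max int
}

// release models checkCachingLocked for a refcount that just hit zero: move
// an already cached entry to the front, otherwise insert it and evict the
// back of the list once the cache is over-full.
func (c *lru) release(e *entry) {
	if e.elem != nil {
		c.l.MoveToFront(e.elem)
		return
	}
	e.elem = c.l.PushFront(e)
	if c.l.Len() > c.max {
		victim := c.l.Remove(c.l.Back()).(*entry)
		victim.elem = nil
		fmt.Println("evicted:", victim.name) // destroyLocked in the real code
	}
}

func main() {
	c := &lru{l: list.New(), max: 2}
	a, b, d := &entry{name: "a"}, &entry{name: "b"}, &entry{name: "d"}
	c.release(a)
	c.release(b)
	c.release(a) // moves "a" to the front
	c.release(d) // cache is full; evicts "b", the least recently used
}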
diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD
index 66fa1ad40..03c8e2f38 100644
--- a/pkg/sentry/hostmm/BUILD
+++ b/pkg/sentry/hostmm/BUILD
@@ -12,8 +12,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/fd",
- "//pkg/hostarch",
+ "//pkg/eventfd",
"//pkg/log",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/sentry/hostmm/hostmm.go b/pkg/sentry/hostmm/hostmm.go
index 285ea9050..5df06a60f 100644
--- a/pkg/sentry/hostmm/hostmm.go
+++ b/pkg/sentry/hostmm/hostmm.go
@@ -21,9 +21,7 @@ import (
"os"
"path"
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/fd"
- "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/log"
)
@@ -54,7 +52,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error)
}
defer eventControlFile.Close()
- eventFD, err := newEventFD()
+ eventFD, err := eventfd.Create()
if err != nil {
return nil, err
}
@@ -75,20 +73,11 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error)
const stopVal = 1 << 63
stopCh := make(chan struct{})
go func() { // S/R-SAFE: f provides synchronization if necessary
- rw := fd.NewReadWriter(eventFD.FD())
- var buf [sizeofUint64]byte
for {
- n, err := rw.Read(buf[:])
+ val, err := eventFD.Read()
if err != nil {
- if err == unix.EINTR {
- continue
- }
panic(fmt.Sprintf("failed to read from memory pressure level eventfd: %v", err))
}
- if n != sizeofUint64 {
- panic(fmt.Sprintf("short read from memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64))
- }
- val := hostarch.ByteOrder.Uint64(buf[:])
if val >= stopVal {
// Assume this was due to the notifier's "destructor" (the
// function returned by NotifyCurrentMemcgPressureCallback
@@ -101,30 +90,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error)
}
}()
return func() {
- rw := fd.NewReadWriter(eventFD.FD())
- var buf [sizeofUint64]byte
- hostarch.ByteOrder.PutUint64(buf[:], stopVal)
- for {
- n, err := rw.Write(buf[:])
- if err != nil {
- if err == unix.EINTR {
- continue
- }
- panic(fmt.Sprintf("failed to write to memory pressure level eventfd: %v", err))
- }
- if n != sizeofUint64 {
- panic(fmt.Sprintf("short write to memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64))
- }
- break
- }
+ eventFD.Write(stopVal)
<-stopCh
}, nil
}
-
-func newEventFD() (*fd.FD, error) {
- f, _, e := unix.Syscall(unix.SYS_EVENTFD2, 0, 0, 0)
- if e != 0 {
- return nil, fmt.Errorf("failed to create eventfd: %v", e)
- }
- return fd.New(int(f)), nil
-}
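The hand-rolled eventfd syscall, the EINTR retry loops, and the byte-order plumbing all collapse into the new pkg/eventfd package introduced by this change (see the diffstat). Going only by the calls visible here — eventfd.Create(), Read() yielding a uint64, Write(uint64) — usage reduces to roughly the following sketch; anything beyond those three calls is an assumption.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/eventfd"
)

const stopVal = 1 << 63 // same shutdown sentinel as the notifier above

func main() {
	efd, err := eventfd.Create()
	if err != nil {
		panic(err)
	}
	done := make(chan struct{})
	go func() {
		defer close(done)
		for {
			// Read blocks until the counter is non-zero; EINTR is
			// retried inside the package, which is why the old
			// unix.EINTR handling could be deleted.
			val, err := efd.Read()
			if err != nil {
				panic(err)
			}
			if val >= stopVal {
				return
			}
			fmt.Println("notified, counter was:", val)
		}
	}()
	efd.Write(1)       // fire one notification (may coalesce with the next write)
	efd.Write(stopVal) // tell the reader goroutine to exit
	<-done
	// Cleanup omitted; the wrapper's close method isn't shown in this diff.
}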
diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD
index 5bba9de0b..2363cec5f 100644
--- a/pkg/sentry/inet/BUILD
+++ b/pkg/sentry/inet/BUILD
@@ -1,13 +1,26 @@
load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(
default_visibility = ["//:sandbox"],
licenses = ["notice"],
)
+go_template_instance(
+ name = "atomicptr_netns",
+ out = "atomicptr_netns_unsafe.go",
+ package = "inet",
+ prefix = "Namespace",
+ template = "//pkg/sync/atomicptr:generic_atomicptr",
+ types = {
+ "Value": "Namespace",
+ },
+)
+
go_library(
name = "inet",
srcs = [
+ "atomicptr_netns_unsafe.go",
"context.go",
"inet.go",
"namespace.go",
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index e4e0dc04f..c0f13bf52 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -268,6 +268,7 @@ go_library(
"//pkg/sentry/mm",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
+ "//pkg/sentry/seccheck",
"//pkg/sentry/socket/netlink/port",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/time",
@@ -281,7 +282,6 @@ go_library(
"//pkg/state/wire",
"//pkg/sync",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/stack",
"//pkg/usermem",
diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD
index 7a1a36454..9aa03f506 100644
--- a/pkg/sentry/kernel/auth/BUILD
+++ b/pkg/sentry/kernel/auth/BUILD
@@ -66,6 +66,5 @@ go_library(
"//pkg/errors/linuxerr",
"//pkg/log",
"//pkg/sync",
- "//pkg/syserror",
],
)
diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go
index c93ef6ac1..a0e291f58 100644
--- a/pkg/sentry/kernel/cgroup.go
+++ b/pkg/sentry/kernel/cgroup.go
@@ -196,6 +196,7 @@ func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Files
// uniqueness of controllers enforced by Register, drop the
// dying hierarchy now. The eventual unregister by the FS
// teardown will become a no-op.
+ r.unregisterLocked(h.id)
return nil
}
return h.fs
diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD
index 564c3d42e..f240a68aa 100644
--- a/pkg/sentry/kernel/eventfd/BUILD
+++ b/pkg/sentry/kernel/eventfd/BUILD
@@ -9,13 +9,13 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/fdnotifier",
"//pkg/hostarch",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go
index 4466fbc9d..5ea44a2c2 100644
--- a/pkg/sentry/kernel/eventfd/eventfd.go
+++ b/pkg/sentry/kernel/eventfd/eventfd.go
@@ -22,13 +22,13 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -145,7 +145,7 @@ func (e *EventOperations) hostRead(ctx context.Context, dst usermem.IOSequence)
if _, err := unix.Read(e.hostfd, buf[:]); err != nil {
if err == unix.EWOULDBLOCK {
- return syserror.ErrWouldBlock
+ return linuxerr.ErrWouldBlock
}
return err
}
@@ -165,7 +165,7 @@ func (e *EventOperations) read(ctx context.Context, dst usermem.IOSequence) erro
// We can't complete the read if the value is currently zero.
if e.val == 0 {
e.mu.Unlock()
- return syserror.ErrWouldBlock
+ return linuxerr.ErrWouldBlock
}
// Update the value based on the mode the event is operating in.
@@ -198,7 +198,7 @@ func (e *EventOperations) hostWrite(val uint64) error {
hostarch.ByteOrder.PutUint64(buf[:], val)
_, err := unix.Write(e.hostfd, buf[:])
if err == unix.EWOULDBLOCK {
- return syserror.ErrWouldBlock
+ return linuxerr.ErrWouldBlock
}
return err
}
@@ -230,7 +230,7 @@ func (e *EventOperations) Signal(val uint64) error {
// uint64 minus 1.
if val > math.MaxUint64-1-e.val {
e.mu.Unlock()
- return syserror.ErrWouldBlock
+ return linuxerr.ErrWouldBlock
}
e.val += val
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index cfdea5cf7..c897e3a5f 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -42,7 +42,6 @@ go_library(
"//pkg/log",
"//pkg/sentry/memmap",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
index f5c364c96..2c9ea65aa 100644
--- a/pkg/sentry/kernel/futex/futex.go
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// KeyKind indicates the type of a Key.
@@ -166,7 +165,7 @@ func atomicOp(t Target, addr hostarch.Addr, opIn uint32) (bool, error) {
case linux.FUTEX_OP_XOR:
newVal = oldVal ^ opArg
default:
- return false, syserror.ENOSYS
+ return false, linuxerr.ENOSYS
}
prev, err := t.CompareAndSwapUint32(addr, oldVal, newVal)
if err != nil {
@@ -192,7 +191,7 @@ func atomicOp(t Target, addr hostarch.Addr, opIn uint32) (bool, error) {
case linux.FUTEX_OP_CMP_GE:
return oldVal >= cmpArg, nil
default:
- return false, syserror.ENOSYS
+ return false, linuxerr.ENOSYS
}
}
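Both default: branches above reject unrecognized fields of the packed FUTEX_WAKE_OP argument. Per futex(2), the 32-bit value carries op:4 | cmp:4 | oparg:12 | cmparg:12, high bits first (the op nibble may additionally carry the FUTEX_OP_OPARG_SHIFT flag), so a decoder splits it as:

package main

import "fmt"

// decode unpacks a FUTEX_WAKE_OP argument as laid out in futex(2): the top 4
// bits are the operation, the next 4 the comparison, then two 12-bit
// immediates.
func decode(opIn uint32) (op, cmp, opArg, cmpArg uint32) {
	return opIn >> 28, (opIn >> 24) & 0xf, (opIn >> 12) & 0xfff, opIn & 0xfff
}

func main() {
	// FUTEX_OP_ADD (1) with oparg 5, compared via FUTEX_OP_CMP_EQ (0)
	// against cmparg 7.
	opIn := uint32(1)<<28 | uint32(5)<<12 | uint32(7)
	fmt.Println(decode(opIn)) // 1 0 5 7
}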
diff --git a/pkg/sentry/kernel/ipc/object.go b/pkg/sentry/kernel/ipc/object.go
index 387b35e7e..facd157c7 100644
--- a/pkg/sentry/kernel/ipc/object.go
+++ b/pkg/sentry/kernel/ipc/object.go
@@ -19,6 +19,8 @@ package ipc
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
@@ -113,3 +115,36 @@ func (o *Object) CheckPermissions(creds *auth.Credentials, req fs.PermMask) bool
}
return creds.HasCapabilityIn(linux.CAP_IPC_OWNER, o.UserNS)
}
+
+// Set modifies attributes for an IPC object. See *ctl(IPC_SET).
+//
+// Precondition: Mechanism.mu must be held.
+func (o *Object) Set(ctx context.Context, perm *linux.IPCPerm) error {
+ creds := auth.CredentialsFromContext(ctx)
+ uid := creds.UserNamespace.MapToKUID(auth.UID(perm.UID))
+ gid := creds.UserNamespace.MapToKGID(auth.GID(perm.GID))
+ if !uid.Ok() || !gid.Ok() {
+ // The man pages don't specify an errno for invalid uid/gid, but EINVAL
+ // is generally used for invalid arguments.
+ return linuxerr.EINVAL
+ }
+
+ if !o.CheckOwnership(creds) {
+ // "The argument cmd has the value IPC_SET or IPC_RMID, but the
+ // effective user ID of the calling process is not the creator (as
+ // found in msg_perm.cuid) or the owner (as found in msg_perm.uid)
+ // of the message queue, and the caller is not privileged (Linux:
+ // does not have the CAP_SYS_ADMIN capability)."
+ return linuxerr.EPERM
+ }
+
+ // User may only modify the lower 9 bits of the mode. All the other bits are
+ // always 0 for the underlying inode.
+ mode := linux.FileMode(perm.Mode & 0x1ff)
+
+ o.Perms = fs.FilePermsFromMode(mode)
+ o.Owner.UID = uid
+ o.Owner.GID = gid
+
+ return nil
+}
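The 0x1ff mask keeps exactly the nine rwxrwxrwx bits, silently dropping setuid/setgid/sticky and anything else userspace passes in, which matches the note that the other bits are always zero on the underlying inode. Concretely:

package main

import "fmt"

func main() {
	// 0o4777 arrives with the setuid bit (0o4000) set; IPC_SET masks it off.
	mode := uint16(0o4777)
	fmt.Printf("%#o\n", mode&0x1ff) // 0777
}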
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index df5160b67..f913d25db 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -78,11 +78,19 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-// VFS2Enabled is set to true when VFS2 is enabled. Added as a global for allow
-// easy access everywhere. To be removed once VFS2 becomes the default.
+// VFS2Enabled is set to true when VFS2 is enabled. Added as a global to allow
+// easy access everywhere.
+//
+// TODO(gvisor.dev/issue/1624): Remove when VFS1 is no longer used.
var VFS2Enabled = false
-// FUSEEnabled is set to true when FUSE is enabled. Added as a global for allow
+// LISAFSEnabled is set to true when lisafs protocol is enabled. Added as a
+// global to allow easy access everywhere.
+//
+// TODO(gvisor.dev/issue/6319): Remove when lisafs is default.
+var LISAFSEnabled = false
+
+// FUSEEnabled is set to true when FUSE is enabled. Added as a global to allow
// easy access everywhere. To be removed once FUSE is completed.
var FUSEEnabled = false
diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go
index fab396d7c..c7c5e41fb 100644
--- a/pkg/sentry/kernel/msgqueue/msgqueue.go
+++ b/pkg/sentry/kernel/msgqueue/msgqueue.go
@@ -129,6 +129,16 @@ type Message struct {
Size uint64
}
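+// makeCopy returns a deep copy of m, so that the copy's Text does not alias
+// the original's buffer.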
+func (m *Message) makeCopy() *Message {
+ new := &Message{
+ Type: m.Type,
+ Size: m.Size,
+ }
+ new.Text = make([]byte, len(m.Text))
+ copy(new.Text, m.Text)
+ return new
+}
+
// Blocker is used for blocking Queue.Send and Queue.Receive calls. It serves
// as an abstracted version of kernel.Task; kernel.Task is not used directly
// to prevent circular dependencies.
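makeCopy matters because Queue.Copy (changed below) must not hand callers a slice that aliases the queued message. A hypothetical snippet inside this package showing the aliasing bug a shallow copy would allow:
orig := &Message{Type: 1, Text: []byte("hello"), Size: 5}
alias := orig           // what a shallow Copy would return
alias.Text[0] = 'X'     // mutates the message still sitting in the queue
safe := orig.makeCopy() // deep copy: Text is freshly allocated
safe.Text[0] = 'Y'      // orig.Text is unaffected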
@@ -206,6 +216,48 @@ func (r *Registry) FindByID(id ipc.ID) (*Queue, error) {
return mech.(*Queue), nil
}
+// IPCInfo reports global parameters for message queues. See msgctl(IPC_INFO).
+func (r *Registry) IPCInfo(ctx context.Context) *linux.MsgInfo {
+ return &linux.MsgInfo{
+ MsgPool: linux.MSGPOOL,
+ MsgMap: linux.MSGMAP,
+ MsgMax: linux.MSGMAX,
+ MsgMnb: linux.MSGMNB,
+ MsgMni: linux.MSGMNI,
+ MsgSsz: linux.MSGSSZ,
+ MsgTql: linux.MSGTQL,
+ MsgSeg: linux.MSGSEG,
+ }
+}
+
+// MsgInfo reports global parameters and system-wide usage for message queues.
+// See msgctl(MSG_INFO).
+func (r *Registry) MsgInfo(ctx context.Context) *linux.MsgInfo {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ var messages, bytes uint64
+ r.reg.ForAllObjects(
+ func(o ipc.Mechanism) {
+ q := o.(*Queue)
+ q.mu.Lock()
+ messages += q.messageCount
+ bytes += q.byteCount
+ q.mu.Unlock()
+ },
+ )
+
+ return &linux.MsgInfo{
+ MsgPool: int32(r.reg.ObjectCount()),
+ MsgMap: int32(messages),
+ MsgTql: int32(bytes),
+ MsgMax: linux.MSGMAX,
+ MsgMnb: linux.MSGMNB,
+ MsgMni: linux.MSGMNI,
+ MsgSsz: linux.MSGSSZ,
+ MsgSeg: linux.MSGSEG,
+ }
+}
+
// Send appends a message to the message queue, and returns an error if sending
// fails. See msgsnd(2).
func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) error {
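IPCInfo and MsgInfo return the same struct with different semantics: IPC_INFO reports static limits, while MSG_INFO reuses MsgPool/MsgMap/MsgTql for live usage counters. A hedged usage sketch (registry and ctx are assumed to be in scope):
limits := registry.IPCInfo(ctx) // MsgPool/MsgMap/MsgTql are the MSGPOOL/MSGMAP/MSGTQL constants
usage := registry.MsgInfo(ctx)  // MsgPool = #queues, MsgMap = #messages, MsgTql = #bytes queued
_ = limits.MsgMnb               // MSGMNB: max bytes per queue, identical in both variants
_ = usage.MsgPool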
@@ -413,7 +465,7 @@ func (q *Queue) Copy(mType int64) (*Message, error) {
if msg == nil {
return nil, linuxerr.ENOMSG
}
- return msg, nil
+ return msg.makeCopy(), nil
}
// msgOfType returns the first message with the specified type, nil if no
@@ -465,6 +517,73 @@ func (q *Queue) msgAtIndex(mType int64) *Message {
return msg
}
+// Set modifies some values of the queue. See msgctl(IPC_SET).
+func (q *Queue) Set(ctx context.Context, ds *linux.MsqidDS) error {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ creds := auth.CredentialsFromContext(ctx)
+ if ds.MsgQbytes > maxQueueBytes && !creds.HasCapabilityIn(linux.CAP_SYS_RESOURCE, q.obj.UserNS) {
+ // "An attempt (IPC_SET) was made to increase msg_qbytes beyond the
+ // system parameter MSGMNB, but the caller is not privileged (Linux:
+ // does not have the CAP_SYS_RESOURCE capability)."
+ return linuxerr.EPERM
+ }
+
+ if err := q.obj.Set(ctx, &ds.MsgPerm); err != nil {
+ return err
+ }
+
+ q.maxBytes = ds.MsgQbytes
+ q.changeTime = ktime.NowFromContext(ctx)
+ return nil
+}
+
+// Stat returns a MsqidDS object filled with information about the queue. See
+// msgctl(IPC_STAT) and msgctl(MSG_STAT).
+func (q *Queue) Stat(ctx context.Context) (*linux.MsqidDS, error) {
+ return q.stat(ctx, fs.PermMask{Read: true})
+}
+
+// StatAny is similar to Queue.Stat, but doesn't require read permission. See
+// msgctl(MSG_STAT_ANY).
+func (q *Queue) StatAny(ctx context.Context) (*linux.MsqidDS, error) {
+ return q.stat(ctx, fs.PermMask{})
+}
+
+// stat returns a MsqidDS object filled with information about the queue. An
+// error is returned if the user doesn't have the specified permissions.
+func (q *Queue) stat(ctx context.Context, mask fs.PermMask) (*linux.MsqidDS, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ creds := auth.CredentialsFromContext(ctx)
+ if !q.obj.CheckPermissions(creds, mask) {
+ // "The caller must have read permission on the message queue."
+ return nil, linuxerr.EACCES
+ }
+
+ return &linux.MsqidDS{
+ MsgPerm: linux.IPCPerm{
+ Key: uint32(q.obj.Key),
+ UID: uint32(creds.UserNamespace.MapFromKUID(q.obj.Owner.UID)),
+ GID: uint32(creds.UserNamespace.MapFromKGID(q.obj.Owner.GID)),
+ CUID: uint32(creds.UserNamespace.MapFromKUID(q.obj.Creator.UID)),
+ CGID: uint32(creds.UserNamespace.MapFromKGID(q.obj.Creator.GID)),
+ Mode: uint16(q.obj.Perms.LinuxMode()),
+ Seq: 0, // IPC sequences not supported.
+ },
+ MsgStime: q.sendTime.TimeT(),
+ MsgRtime: q.receiveTime.TimeT(),
+ MsgCtime: q.changeTime.TimeT(),
+ MsgCbytes: q.byteCount,
+ MsgQnum: q.messageCount,
+ MsgQbytes: q.maxBytes,
+ MsgLspid: q.sendPID,
+ MsgLrpid: q.receivePID,
+ }, nil
+}
+
// Lock implements ipc.Mechanism.Lock.
func (q *Queue) Lock() {
q.mu.Lock()
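Stat and StatAny differ only in the fs.PermMask passed to stat. A sketch of how a msgctl handler might dispatch between them (the switch is hypothetical; the methods are the ones added above):
switch cmd {
case linux.IPC_STAT, linux.MSG_STAT:
	return q.Stat(ctx) // requires read permission on the queue
case linux.MSG_STAT_ANY:
	return q.StatAny(ctx) // fs.PermMask{}: skips the permission check
}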
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 94ebac7c5..5b2bac783 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -31,7 +31,6 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/vfs",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
@@ -51,7 +50,6 @@ go_test(
"//pkg/errors/linuxerr",
"//pkg/sentry/contexttest",
"//pkg/sentry/fs",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go
index 08786d704..615591507 100644
--- a/pkg/sentry/kernel/pipe/node.go
+++ b/pkg/sentry/kernel/pipe/node.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// inodeOperations implements fs.InodeOperations for pipes.
@@ -95,7 +94,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() {
if !waitFor(&i.mu, &i.wWakeup, ctx) {
r.DecRef(ctx)
- return nil, syserror.ErrInterrupted
+ return nil, linuxerr.ErrInterrupted
}
}
@@ -118,7 +117,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
if !waitFor(&i.mu, &i.rWakeup, ctx) {
w.DecRef(ctx)
- return nil, syserror.ErrInterrupted
+ return nil, linuxerr.ErrInterrupted
}
}
return w, nil
diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go
index d25cf658e..31bd7910a 100644
--- a/pkg/sentry/kernel/pipe/node_test.go
+++ b/pkg/sentry/kernel/pipe/node_test.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/syserror"
)
type sleeper struct {
@@ -240,7 +239,7 @@ func TestBlockedOpenIsCancellable(t *testing.T) {
// If the cancel on the sleeper didn't work, the open for read would never
// return.
res := <-done
- if res.error != syserror.ErrInterrupted {
+ if res.error != linuxerr.ErrInterrupted {
t.Fatalf("Cancellation didn't cause GetFile to return fs.ErrInterrupted, got %v.",
res.error)
}
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 85e3ce9f4..86beee6fe 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -27,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -201,7 +200,7 @@ func (p *Pipe) peekLocked(count int64, f func(safemem.BlockSeq) (uint64, error))
if !p.HasWriters() {
return 0, io.EOF
}
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
count = p.size
}
@@ -250,7 +249,7 @@ func (p *Pipe) writeLocked(count int64, f func(safemem.BlockSeq) (uint64, error)
avail := p.max - p.size
if avail == 0 {
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
short := false
if count > avail {
@@ -258,7 +257,7 @@ func (p *Pipe) writeLocked(count int64, f func(safemem.BlockSeq) (uint64, error)
// (PIPE_BUF) be atomic, but requires no atomicity for writes
// larger than this.
if count <= atomicIOBytes {
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
count = avail
short = true
@@ -307,7 +306,7 @@ func (p *Pipe) writeLocked(count int64, f func(safemem.BlockSeq) (uint64, error)
// If we shortened the write, adjust the returned error appropriately.
if short {
- return done, syserror.ErrWouldBlock
+ return done, linuxerr.ErrWouldBlock
}
return done, nil
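Callers treat linuxerr.ErrWouldBlock as "register a waiter, block, retry" rather than as a terminal error; the tests below exercise exactly this loop. A condensed sketch (r, dst, and ch are assumed, following the test code):
for {
	n, err := r.Readv(ctx, dst)
	dst = dst.DropFirst64(n)
	if err == linuxerr.ErrWouldBlock {
		<-ch // wait until the writer makes progress, then retry
		continue
	}
	return err
}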
diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go
index 867f4a76b..aa3ab305d 100644
--- a/pkg/sentry/kernel/pipe/pipe_test.go
+++ b/pkg/sentry/kernel/pipe/pipe_test.go
@@ -18,8 +18,8 @@ import (
"bytes"
"testing"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -51,8 +51,8 @@ func TestPipeReadBlock(t *testing.T) {
defer w.DecRef(ctx)
n, err := r.Readv(ctx, usermem.BytesIOSequence(make([]byte, 1)))
- if n != 0 || err != syserror.ErrWouldBlock {
- t.Fatalf("Readv: got (%d, %v), wanted (0, %v)", n, err, syserror.ErrWouldBlock)
+ if n != 0 || err != linuxerr.ErrWouldBlock {
+ t.Fatalf("Readv: got (%d, %v), wanted (0, %v)", n, err, linuxerr.ErrWouldBlock)
}
}
@@ -67,7 +67,7 @@ func TestPipeWriteBlock(t *testing.T) {
msg := make([]byte, capacity+1)
n, err := w.Writev(ctx, usermem.BytesIOSequence(msg))
- if wantN, wantErr := int64(capacity), syserror.ErrWouldBlock; n != wantN || err != wantErr {
+ if wantN, wantErr := int64(capacity), linuxerr.ErrWouldBlock; n != wantN || err != wantErr {
t.Fatalf("Writev: got (%d, %v), wanted (%d, %v)", n, err, wantN, wantErr)
}
}
@@ -102,7 +102,7 @@ func TestPipeWriteUntilEnd(t *testing.T) {
for {
n, err := r.Readv(ctx, dst)
dst = dst.DropFirst64(n)
- if err == syserror.ErrWouldBlock {
+ if err == linuxerr.ErrWouldBlock {
select {
case <-ch:
continue
@@ -129,7 +129,7 @@ func TestPipeWriteUntilEnd(t *testing.T) {
for src.NumBytes() != 0 {
n, err := w.Writev(ctx, src)
src = src.DropFirst64(n)
- if err == syserror.ErrWouldBlock {
+ if err == linuxerr.ErrWouldBlock {
<-ch
continue
}
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index 077d5fd7f..a6f1989f5 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -121,7 +120,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s
// writer, we have to wait for a writer to open the other end.
if vp.pipe.isNamed && statusFlags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) {
fd.DecRef(ctx)
- return nil, syserror.EINTR
+ return nil, linuxerr.EINTR
}
case writable:
@@ -137,7 +136,7 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s
// Wait for a reader to open the other end.
if !waitFor(&vp.mu, &vp.rWakeup, ctx) {
fd.DecRef(ctx)
- return nil, syserror.EINTR
+ return nil, linuxerr.EINTR
}
}
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 079294f81..717c9a6b3 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -465,7 +464,7 @@ func (t *Task) ptraceUnfreezeLocked() {
// stop.
func (t *Task) ptraceUnstop(mode ptraceSyscallMode, singlestep bool, sig linux.Signal) error {
if sig != 0 && !sig.IsValid() {
- return syserror.EIO
+ return linuxerr.EIO
}
t.tg.pidns.owner.mu.Lock()
defer t.tg.pidns.owner.mu.Unlock()
@@ -532,7 +531,7 @@ func (t *Task) ptraceAttach(target *Task, seize bool, opts uintptr) error {
}
if seize {
if err := target.ptraceSetOptionsLocked(opts); err != nil {
- return syserror.EIO
+ return linuxerr.EIO
}
}
target.ptraceTracer.Store(t)
@@ -569,7 +568,7 @@ func (t *Task) ptraceAttach(target *Task, seize bool, opts uintptr) error {
// ptrace stop.
func (t *Task) ptraceDetach(target *Task, sig linux.Signal) error {
if sig != 0 && !sig.IsValid() {
- return syserror.EIO
+ return linuxerr.EIO
}
t.tg.pidns.owner.mu.Lock()
defer t.tg.pidns.owner.mu.Unlock()
@@ -967,7 +966,7 @@ func (t *Task) ptraceInterrupt(target *Task) error {
return linuxerr.ESRCH
}
if !target.ptraceSeized {
- return syserror.EIO
+ return linuxerr.EIO
}
target.tg.signalHandlers.mu.Lock()
defer target.tg.signalHandlers.mu.Unlock()
@@ -1030,7 +1029,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data hostarch.Addr) error {
if req == linux.PTRACE_ATTACH || req == linux.PTRACE_SEIZE {
seize := req == linux.PTRACE_SEIZE
if seize && addr != 0 {
- return syserror.EIO
+ return linuxerr.EIO
}
return t.ptraceAttach(target, seize, uintptr(data))
}
@@ -1120,13 +1119,13 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data hostarch.Addr) error {
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
if !target.ptraceSeized {
- return syserror.EIO
+ return linuxerr.EIO
}
if target.ptraceSiginfo == nil {
- return syserror.EIO
+ return linuxerr.EIO
}
if target.ptraceSiginfo.Code>>8 != linux.PTRACE_EVENT_STOP {
- return syserror.EIO
+ return linuxerr.EIO
}
target.tg.signalHandlers.mu.Lock()
defer target.tg.signalHandlers.mu.Unlock()
diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go
index 63422e155..564add01b 100644
--- a/pkg/sentry/kernel/ptrace_amd64.go
+++ b/pkg/sentry/kernel/ptrace_amd64.go
@@ -19,8 +19,8 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -88,6 +88,6 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data hostarch.Addr) err
return err
default:
- return syserror.EIO
+ return linuxerr.EIO
}
}
diff --git a/pkg/sentry/kernel/ptrace_arm64.go b/pkg/sentry/kernel/ptrace_arm64.go
index 27514d67b..7c2b94339 100644
--- a/pkg/sentry/kernel/ptrace_arm64.go
+++ b/pkg/sentry/kernel/ptrace_arm64.go
@@ -18,11 +18,11 @@
package kernel
import (
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
- "gvisor.dev/gvisor/pkg/syserror"
)
// ptraceArch implements arch-specific ptrace commands.
func (t *Task) ptraceArch(target *Task, req int64, addr, data hostarch.Addr) error {
- return syserror.EIO
+ return linuxerr.EIO
}
diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go
index 54ca43c2e..0d66648c3 100644
--- a/pkg/sentry/kernel/seccomp.go
+++ b/pkg/sentry/kernel/seccomp.go
@@ -18,9 +18,9 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/syserror"
)
const maxSyscallFilterInstructions = 1 << 15
@@ -176,7 +176,7 @@ func (t *Task) AppendSyscallFilter(p bpf.Program, syncAll bool) error {
}
if totalLength > maxSyscallFilterInstructions {
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
newFilters = append(newFilters, p)
diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD
index 2ae08ed12..6aa74219e 100644
--- a/pkg/sentry/kernel/semaphore/BUILD
+++ b/pkg/sentry/kernel/semaphore/BUILD
@@ -31,7 +31,6 @@ go_library(
"//pkg/sentry/kernel/ipc",
"//pkg/sentry/kernel/time",
"//pkg/sync",
- "//pkg/syserror",
],
)
@@ -43,9 +42,9 @@ go_test(
deps = [
"//pkg/abi/linux", # keep
"//pkg/context", # keep
+ "//pkg/errors/linuxerr", #keep
"//pkg/sentry/contexttest", # keep
"//pkg/sentry/kernel/auth", # keep
"//pkg/sentry/kernel/ipc", # keep
- "//pkg/syserror", # keep
],
)
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index 8610d3fc1..28e466948 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
const (
@@ -151,10 +150,10 @@ func (r *Registry) FindOrCreate(ctx context.Context, key ipc.Key, nsems int32, m
// reg.objects and the index map in a registry have the same size, so it
// suffices to check reg.objects against the system limit here.
if r.reg.ObjectCount() >= setsMax {
- return nil, syserror.ENOSPC
+ return nil, linuxerr.ENOSPC
}
if r.totalSems() > int(semsTotalMax-nsems) {
- return nil, syserror.ENOSPC
+ return nil, linuxerr.ENOSPC
}
// Finally create a new set.
@@ -337,19 +336,15 @@ func (s *Set) Size() int {
return len(s.sems)
}
-// Change changes some fields from the set atomically.
-func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.FileOwner, perms fs.FilePermissions) error {
+// Set modifies attributes for a semaphore set. See semctl(IPC_SET).
+func (s *Set) Set(ctx context.Context, ds *linux.SemidDS) error {
s.mu.Lock()
defer s.mu.Unlock()
- // "The effective UID of the calling process must match the owner or creator
- // of the semaphore set, or the caller must be privileged."
- if !s.obj.CheckOwnership(creds) {
- return linuxerr.EACCES
+ if err := s.obj.Set(ctx, &ds.SemPerm); err != nil {
+ return err
}
- s.obj.Owner = owner
- s.obj.Perms = perms
s.changeTime = ktime.NowFromContext(ctx)
return nil
}
@@ -549,7 +544,7 @@ func (s *Set) ExecuteOps(ctx context.Context, ops []linux.Sembuf, creds *auth.Cr
// Did it race with a removal operation?
if s.dead {
- return nil, 0, syserror.EIDRM
+ return nil, 0, linuxerr.EIDRM
}
// Validate the operations.
@@ -588,7 +583,7 @@ func (s *Set) executeOps(ctx context.Context, ops []linux.Sembuf, pid int32) (ch
if tmpVals[op.SemNum] != 0 {
// Semaphore isn't 0, must wait.
if op.SemFlg&linux.IPC_NOWAIT != 0 {
- return nil, 0, syserror.ErrWouldBlock
+ return nil, 0, linuxerr.ErrWouldBlock
}
w := newWaiter(op.SemOp)
@@ -604,7 +599,7 @@ func (s *Set) executeOps(ctx context.Context, ops []linux.Sembuf, pid int32) (ch
if -op.SemOp > tmpVals[op.SemNum] {
// Not enough resources, must wait.
if op.SemFlg&linux.IPC_NOWAIT != 0 {
- return nil, 0, syserror.ErrWouldBlock
+ return nil, 0, linuxerr.ErrWouldBlock
}
w := newWaiter(op.SemOp)
diff --git a/pkg/sentry/kernel/semaphore/semaphore_test.go b/pkg/sentry/kernel/semaphore/semaphore_test.go
index 2e4ab8121..59ac92ef1 100644
--- a/pkg/sentry/kernel/semaphore/semaphore_test.go
+++ b/pkg/sentry/kernel/semaphore/semaphore_test.go
@@ -19,10 +19,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
- "gvisor.dev/gvisor/pkg/syserror"
)
func executeOps(ctx context.Context, t *testing.T, set *Set, ops []linux.Sembuf, block bool) chan struct{} {
@@ -124,14 +124,14 @@ func TestNoWait(t *testing.T) {
ops[0].SemOp = -2
ops[0].SemFlg = linux.IPC_NOWAIT
- if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock {
- t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock)
+ if _, _, err := set.executeOps(ctx, ops, 123); err != linuxerr.ErrWouldBlock {
+ t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, linuxerr.ErrWouldBlock)
}
ops[0].SemOp = 0
ops[0].SemFlg = linux.IPC_NOWAIT
- if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock {
- t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock)
+ if _, _, err := set.executeOps(ctx, ops, 123); err != linuxerr.ErrWouldBlock {
+ t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, linuxerr.ErrWouldBlock)
}
}
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index 4e8deac4c..2547957ba 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -42,7 +42,6 @@ go_library(
"//pkg/sentry/pgalloc",
"//pkg/sentry/usage",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index 2abf467d7..ab938fa3c 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -49,7 +49,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Registry tracks all shared memory segments in an IPC namespace. The registry
@@ -151,7 +150,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key ipc.Key, siz
if r.reg.ObjectCount() >= linux.SHMMNI {
// "All possible shared memory IDs have been taken (SHMMNI) ..."
// - man shmget(2)
- return nil, syserror.ENOSPC
+ return nil, linuxerr.ENOSPC
}
if !private {
@@ -184,7 +183,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key ipc.Key, siz
// "... allocating a segment of the requested size would cause the
// system to exceed the system-wide limit on shared memory (SHMALL)."
// - man shmget(2)
- return nil, syserror.ENOSPC
+ return nil, linuxerr.ENOSPC
}
// Need to create a new segment.
@@ -521,7 +520,7 @@ func (s *Shm) ConfigureAttach(ctx context.Context, addr hostarch.Addr, opts Atta
s.mu.Lock()
defer s.mu.Unlock()
if s.pendingDestruction && s.ReadRefs() == 0 {
- return memmap.MMapOpts{}, syserror.EIDRM
+ return memmap.MMapOpts{}, linuxerr.EIDRM
}
creds := auth.CredentialsFromContext(ctx)
@@ -619,25 +618,10 @@ func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error {
s.mu.Lock()
defer s.mu.Unlock()
- creds := auth.CredentialsFromContext(ctx)
- if !s.obj.CheckOwnership(creds) {
- return linuxerr.EPERM
- }
-
- uid := creds.UserNamespace.MapToKUID(auth.UID(ds.ShmPerm.UID))
- gid := creds.UserNamespace.MapToKGID(auth.GID(ds.ShmPerm.GID))
- if !uid.Ok() || !gid.Ok() {
- return linuxerr.EINVAL
+ if err := s.obj.Set(ctx, &ds.ShmPerm); err != nil {
+ return err
}
- // User may only modify the lower 9 bits of the mode. All the other bits are
- // always 0 for the underlying inode.
- mode := linux.FileMode(ds.ShmPerm.Mode & 0x1ff)
- s.obj.Perms = fs.FilePermsFromMode(mode)
-
- s.obj.Owner.UID = uid
- s.obj.Owner.GID = gid
-
s.changeTime = ktime.NowFromContext(ctx)
return nil
}
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
index 1110ecca5..4180ca28e 100644
--- a/pkg/sentry/kernel/signalfd/BUILD
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -15,7 +15,6 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/kernel",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
index 47958e2d4..9c5e6698c 100644
--- a/pkg/sentry/kernel/signalfd/signalfd.go
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -99,7 +98,7 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
info, err := s.target.Sigtimedwait(s.Mask(), 0)
if err != nil {
// There must be no signal available.
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
// Copy out the signal info using the specified format.
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index 59eeb253d..b0004482c 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -30,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/seccheck"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
@@ -510,9 +511,7 @@ type Task struct {
numaNodeMask uint64
// netns is the task's network namespace. netns is never nil.
- //
- // netns is protected by mu.
- netns *inet.Namespace
+ netns inet.NamespaceAtomicPtr
// If rseqPreempted is true, before the next call to p.Switch(),
// interrupt rseq critical regions as defined by rseqAddr and
@@ -874,3 +873,23 @@ func (t *Task) ResetKcov() {
t.kcov = nil
}
}
+
+// Preconditions: The TaskSet mutex must be locked.
+func (t *Task) loadSeccheckInfoLocked(req seccheck.TaskFieldSet, mask *seccheck.TaskFieldSet, info *seccheck.TaskInfo) {
+ if req.Contains(seccheck.TaskFieldThreadID) {
+ info.ThreadID = int32(t.k.tasks.Root.tids[t])
+ mask.Add(seccheck.TaskFieldThreadID)
+ }
+ if req.Contains(seccheck.TaskFieldThreadStartTime) {
+ info.ThreadStartTime = t.startTime
+ mask.Add(seccheck.TaskFieldThreadStartTime)
+ }
+ if req.Contains(seccheck.TaskFieldThreadGroupID) {
+ info.ThreadGroupID = int32(t.k.tasks.Root.tgids[t.tg])
+ mask.Add(seccheck.TaskFieldThreadGroupID)
+ }
+ if req.Contains(seccheck.TaskFieldThreadGroupStartTime) {
+ info.ThreadGroupStartTime = t.tg.leader.startTime
+ mask.Add(seccheck.TaskFieldThreadGroupStartTime)
+ }
+}
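loadSeccheckInfoLocked populates only the fields a checker requested, so unrequested fields cost nothing. A hedged sketch of the request/populate handshake (the field-set API is the one used above; the caller is hypothetical):
var req seccheck.TaskFieldSet
req.Add(seccheck.TaskFieldThreadID) // a checker declares what it needs

var mask seccheck.TaskFieldSet
var info seccheck.TaskInfo
t.loadSeccheckInfoLocked(req, &mask, &info) // TaskSet mutex held
// mask now records which fields of info are actually valid.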
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go
index b2520eecf..9bfc155e4 100644
--- a/pkg/sentry/kernel/task_block.go
+++ b/pkg/sentry/kernel/task_block.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// BlockWithTimeout blocks t until an event is received from C, the application
@@ -33,7 +32,7 @@ import (
// and is unspecified if haveTimeout is false.
//
// - An error which is nil if an event is received from C, ETIMEDOUT if the timeout
-// expired, and syserror.ErrInterrupted if t is interrupted.
+// expired, and linuxerr.ErrInterrupted if t is interrupted.
//
// Preconditions: The caller must be running on the task goroutine.
func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time.Duration) (time.Duration, error) {
@@ -67,7 +66,7 @@ func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time.
// application monotonic clock indicates a time of deadline (only if
// haveDeadline is true), or t is interrupted. It returns nil if an event is
// received from C, ETIMEDOUT if the deadline expired, and
-// syserror.ErrInterrupted if t is interrupted.
+// linuxerr.ErrInterrupted if t is interrupted.
//
// Preconditions: The caller must be running on the task goroutine.
func (t *Task) BlockWithDeadline(C <-chan struct{}, haveDeadline bool, deadline ktime.Time) error {
@@ -95,7 +94,7 @@ func (t *Task) BlockWithDeadline(C <-chan struct{}, haveDeadline bool, deadline
// BlockWithTimer blocks t until an event is received from C or tchan, or t is
// interrupted. It returns nil if an event is received from C, ETIMEDOUT if an
-// event is received from tchan, and syserror.ErrInterrupted if t is
+// event is received from tchan, and linuxerr.ErrInterrupted if t is
// interrupted.
//
// Most clients should use BlockWithDeadline or BlockWithTimeout instead.
@@ -106,7 +105,7 @@ func (t *Task) BlockWithTimer(C <-chan struct{}, tchan <-chan struct{}) error {
}
// Block blocks t until an event is received from C or t is interrupted. It
-// returns nil if an event is received from C and syserror.ErrInterrupted if t
+// returns nil if an event is received from C and linuxerr.ErrInterrupted if t
// is interrupted.
//
// Preconditions: The caller must be running on the task goroutine.
@@ -157,7 +156,7 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error {
region.End()
t.SleepFinish(false)
// Return the indicated error on interrupt.
- return syserror.ErrInterrupted
+ return linuxerr.ErrInterrupted
case <-timerChan:
region.End()
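A minimal caller sketch of the Block contract documented above (C and done are assumptions):
for !done() {
	// Block returns nil when C fires and linuxerr.ErrInterrupted when the
	// task is interrupted; the error propagates to the syscall layer, which
	// may restart the syscall.
	if err := t.Block(C); err != nil {
		return err
	}
}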
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index da4b77ca2..a6d8fb163 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/seccheck"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -235,7 +236,23 @@ func (t *Task) Clone(args *linux.CloneArgs) (ThreadID, *SyscallControl, error) {
// nt that it must receive before its task goroutine starts running.
tid := nt.k.tasks.Root.IDOfTask(nt)
defer nt.Start(tid)
- t.traceCloneEvent(tid)
+
+ if seccheck.Global.Enabled(seccheck.PointClone) {
+ mask, info := getCloneSeccheckInfo(t, nt, args)
+ if err := seccheck.Global.Clone(t, mask, &info); err != nil {
+ // nt has been visible to the rest of the system since NewTask, so
+ // it may be blocking execve or a group stop, have been notified
+ // for group signal delivery, had children reparented to it, etc.
+ // Thus we can't just drop it on the floor. Instead, instruct the
+ // task goroutine to exit immediately, as quietly as possible.
+ nt.exitTracerNotified = true
+ nt.exitTracerAcked = true
+ nt.exitParentNotified = true
+ nt.exitParentAcked = true
+ nt.runState = (*runExitMain)(nil)
+ return 0, nil, err
+ }
+ }
// "If fork/clone and execve are allowed by @prog, any child processes will
// be constrained to the same filters and system call ABI as the parent." -
@@ -260,6 +277,7 @@ func (t *Task) Clone(args *linux.CloneArgs) (ThreadID, *SyscallControl, error) {
ntid.CopyOut(t, hostarch.Addr(args.ParentTID))
}
+ t.traceCloneEvent(tid)
kind := ptraceCloneKindClone
if args.Flags&linux.CLONE_VFORK != 0 {
kind = ptraceCloneKindVfork
@@ -279,6 +297,22 @@ func (t *Task) Clone(args *linux.CloneArgs) (ThreadID, *SyscallControl, error) {
return ntid, nil, nil
}
+func getCloneSeccheckInfo(t, nt *Task, args *linux.CloneArgs) (seccheck.CloneFieldSet, seccheck.CloneInfo) {
+ req := seccheck.Global.CloneReq()
+ info := seccheck.CloneInfo{
+ Credentials: t.Credentials(),
+ Args: *args,
+ }
+ var mask seccheck.CloneFieldSet
+ mask.Add(seccheck.CloneFieldCredentials)
+ mask.Add(seccheck.CloneFieldArgs)
+ t.k.tasks.mu.RLock()
+ defer t.k.tasks.mu.RUnlock()
+ t.loadSeccheckInfoLocked(req.Invoker, &mask.Invoker, &info.Invoker)
+ nt.loadSeccheckInfoLocked(req.Created, &mask.Created, &info.Created)
+ return mask, info
+}
+
// maybeBeginVforkStop checks if a previously-started vfork child is still
// running and has not yet released its MM, such that its parent t should enter
// a vforkStop.
@@ -410,7 +444,7 @@ func (t *Task) Unshare(flags int32) error {
t.mu.Unlock()
return linuxerr.EPERM
}
- t.netns = inet.NewNamespace(t.netns)
+ t.netns.Store(inet.NewNamespace(t.netns.Load()))
}
if flags&linux.CLONE_NEWUTS != 0 {
if !haveCapSysAdmin {
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index cf8571262..db91fc4d8 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -66,10 +66,10 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// execStop is a TaskStop that a task sets on itself when it wants to execve
@@ -97,7 +97,7 @@ func (t *Task) Execve(newImage *TaskImage) (*SyscallControl, error) {
// We lost to a racing group-exit, kill, or exec from another thread
// and should just exit.
newImage.release()
- return nil, syserror.EINTR
+ return nil, linuxerr.EINTR
}
// Cancel any racing group stops.
@@ -222,9 +222,15 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
// Update credentials to reflect the execve. This should precede switching
// MMs to ensure that dumpability has been reset first, if needed.
t.updateCredsForExecLocked()
- t.image.release()
+ oldImage := t.image
t.image = *r.image
t.mu.Unlock()
+
+ // Don't hold t.mu while calling oldImage.release(): it may attempt to
+ // acquire TaskImage.MemoryManager.mappingMu, which would be a lock
+ // order violation.
+ oldImage.release()
+
t.unstopVforkParent()
t.p.FullStateChanged()
// NOTE(b/30316266): All locks must be dropped prior to calling Activate.
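The reordering above is an instance of a general pattern: snapshot the guarded value, drop the lock, then run teardown that may itself acquire locks. A generic sketch (names are illustrative):
mu.Lock()
old := guarded      // snapshot the value that needs teardown
guarded = replacement
mu.Unlock()
old.release()       // teardown may take other locks; mu is already dropped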
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index fbfcc19e5..b3931445b 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -32,7 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -230,9 +230,16 @@ func (*runExitMain) execute(t *Task) taskRunState {
t.tg.pidns.owner.mu.Lock()
t.updateRSSLocked()
t.tg.pidns.owner.mu.Unlock()
+
+ // Release the task image resources. Accessing these fields must be
+ // done with t.mu held, but the mm.DecUsers() call must be done outside
+ // of that lock.
t.mu.Lock()
- t.image.release()
+ mm := t.image.MemoryManager
+ t.image.MemoryManager = nil
+ t.image.fu = nil
t.mu.Unlock()
+ mm.DecUsers(t)
// Releasing the MM unblocks a blocked CLONE_VFORK parent.
t.unstopVforkParent()
@@ -859,7 +866,7 @@ func (t *Task) Wait(opts *WaitOptions) (*WaitResult, error) {
return wr, err
}
if err := t.Block(ch); err != nil {
- return wr, syserror.ConvertIntr(err, opts.BlockInterruptErr)
+ return wr, syserr.ConvertIntr(err, opts.BlockInterruptErr)
}
}
}
diff --git a/pkg/sentry/kernel/task_image.go b/pkg/sentry/kernel/task_image.go
index c132c27ef..6002ffb42 100644
--- a/pkg/sentry/kernel/task_image.go
+++ b/pkg/sentry/kernel/task_image.go
@@ -53,7 +53,7 @@ type TaskImage struct {
}
// release releases all resources held by the TaskImage. release is called by
-// the task when it execs into a new TaskImage or exits.
+// the task when it execs into a new TaskImage.
func (image *TaskImage) release() {
// Nil out pointers so that if the task is saved after release, it doesn't
// follow the pointers to possibly now-invalid objects.
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
index 8de08151a..f0c168ecc 100644
--- a/pkg/sentry/kernel/task_log.go
+++ b/pkg/sentry/kernel/task_log.go
@@ -191,9 +191,11 @@ const (
//
// Preconditions: The task's owning TaskSet.mu must be locked.
func (t *Task) updateInfoLocked() {
- // Use the task's TID in the root PID namespace for logging.
+ // Use the task's PID and TID in the root PID namespace for logging.
+ pid := t.tg.pidns.owner.Root.tgids[t.tg]
tid := t.tg.pidns.owner.Root.tids[t]
- t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid))
+ t.logPrefix.Store(fmt.Sprintf("[% 4d:% 4d] ", pid, tid))
+
t.rebuildTraceContext(tid)
}
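For reference, the new prefix renders PID then TID, space-padded; a runnable example with made-up IDs:
package main

import "fmt"

func main() {
	pid, tid := 42, 1337
	fmt.Printf("[% 4d:% 4d] task log line\n", pid, tid)
	// Output: [  42: 1337] task log line
}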
@@ -249,5 +251,9 @@ func (t *Task) traceExecEvent(image *TaskImage) {
return
}
defer file.DecRef(t)
- trace.Logf(t.traceContext, traceCategory, "exec: %s", file.PathnameWithDeleted(t))
+
+ // traceExecEvent may be called before the task goroutine starts, so we
+ // must use the async context.
+ name := file.PathnameWithDeleted(t.AsyncContext())
+ trace.Logf(t.traceContext, traceCategory, "exec: %s", name)
}
diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go
index f7711232c..e31e2b2e8 100644
--- a/pkg/sentry/kernel/task_net.go
+++ b/pkg/sentry/kernel/task_net.go
@@ -20,9 +20,7 @@ import (
// IsNetworkNamespaced returns true if t is in a non-root network namespace.
func (t *Task) IsNetworkNamespaced() bool {
- t.mu.Lock()
- defer t.mu.Unlock()
- return !t.netns.IsRoot()
+ return !t.netns.Load().IsRoot()
}
// NetworkContext returns the network stack used by the task. NetworkContext
@@ -31,14 +29,10 @@ func (t *Task) IsNetworkNamespaced() bool {
// TODO(gvisor.dev/issue/1833): Migrate callers of this method to
// NetworkNamespace().
func (t *Task) NetworkContext() inet.Stack {
- t.mu.Lock()
- defer t.mu.Unlock()
- return t.netns.Stack()
+ return t.netns.Load().Stack()
}
// NetworkNamespace returns the network namespace observed by the task.
func (t *Task) NetworkNamespace() *inet.Namespace {
- t.mu.Lock()
- defer t.mu.Unlock()
- return t.netns
+ return t.netns.Load()
}
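NamespaceAtomicPtr is generated code, but it behaves like a typed atomic pointer, which is what lets the accessors above drop t.mu. A rough stdlib equivalent (assumption: the generated type mirrors sync/atomic semantics):
package main

import (
	"fmt"
	"sync/atomic"
)

type Namespace struct{ root bool }

func main() {
	var netns atomic.Pointer[Namespace]
	netns.Store(&Namespace{root: true})
	// Readers need no mutex: Load is a single atomic read.
	fmt.Println(netns.Load().root) // true
}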
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index 054ff212f..7b336a46b 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -22,6 +22,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/goid"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -29,7 +30,6 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/syserror"
)
// A taskRunState is a reified state in the task state machine. See README.md
@@ -197,8 +197,8 @@ func (app *runApp) execute(t *Task) taskRunState {
// a pending signal, causing another interruption, but that signal should
// not interact with the interrupted syscall.)
if t.haveSyscallReturn {
- if sre, ok := syserror.SyscallRestartErrnoFromReturn(t.Arch().Return()); ok {
- if sre == syserror.ERESTART_RESTARTBLOCK {
+ if sre, ok := linuxerr.SyscallRestartErrorFromReturn(t.Arch().Return()); ok {
+ if sre == linuxerr.ERESTART_RESTARTBLOCK {
t.Debugf("Restarting syscall %d with restart block after errno %d: not interrupted by handled signal", t.Arch().SyscallNo(), sre)
t.Arch().RestartSyscallWithRestartBlock()
} else {
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 7065ac79c..eeb3c5e69 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -28,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -161,7 +160,7 @@ func (t *Task) deliverSignal(info *linux.SignalInfo, act linux.SigAction) taskRu
sigact := computeAction(sig, act)
if t.haveSyscallReturn {
- if sre, ok := syserror.SyscallRestartErrnoFromReturn(t.Arch().Return()); ok {
+ if sre, ok := linuxerr.SyscallRestartErrorFromReturn(t.Arch().Return()); ok {
// Signals that are ignored, cause a thread group stop, or
// terminate the thread group do not interact with interrupted
// syscalls; in Linux terms, they are never returned to the signal
@@ -170,13 +169,13 @@ func (t *Task) deliverSignal(info *linux.SignalInfo, act linux.SigAction) taskRu
// signal that is actually handled (by userspace).
if sigact == SignalActionHandler {
switch {
- case sre == syserror.ERESTARTNOHAND:
+ case sre == linuxerr.ERESTARTNOHAND:
fallthrough
- case sre == syserror.ERESTART_RESTARTBLOCK:
+ case sre == linuxerr.ERESTART_RESTARTBLOCK:
fallthrough
- case (sre == syserror.ERESTARTSYS && act.Flags&linux.SA_RESTART == 0):
+ case (sre == linuxerr.ERESTARTSYS && act.Flags&linux.SA_RESTART == 0):
t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
- t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1)))
+ t.Arch().SetReturn(uintptr(-ExtractErrno(linuxerr.EINTR, -1)))
default:
t.Debugf("Restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
t.Arch().RestartSyscall()
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index 0565059c1..4919dea7c 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// TaskConfig defines the configuration of a new Task (see below).
@@ -141,7 +140,6 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
allowedCPUMask: cfg.AllowedCPUMask.Copy(),
ioUsage: &usage.IO{},
niceness: cfg.Niceness,
- netns: cfg.NetworkNamespace,
utsns: cfg.UTSNamespace,
ipcns: cfg.IPCNamespace,
abstractSockets: cfg.AbstractSocketNamespace,
@@ -153,6 +151,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
containerID: cfg.ContainerID,
cgroups: make(map[Cgroup]struct{}),
}
+ t.netns.Store(cfg.NetworkNamespace)
t.creds.Store(cfg.Credentials)
t.endStopCond.L = &t.tg.signalHandlers.mu
t.ptraceTracer.Store((*Task)(nil))
@@ -170,7 +169,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
// doesn't matter too much since the caller will exit before it returns
// to userspace. If the caller isn't in the same thread group, then
// we're in uncharted territory and can return whatever we want.
- return nil, syserror.EINTR
+ return nil, linuxerr.EINTR
}
if err := ts.assignTIDsLocked(t); err != nil {
return nil, err
@@ -268,7 +267,7 @@ func (ns *PIDNamespace) allocateTID() (ThreadID, error) {
// fail with the error ENOMEM; it is not possible to create a new
// processes [sic] in a PID namespace whose init process has
// terminated." - pid_namespaces(7)
- return 0, syserror.ENOMEM
+ return 0, linuxerr.ENOMEM
}
tid := ns.last
for {
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
index 0586c9def..2b1d7e114 100644
--- a/pkg/sentry/kernel/task_syscall.go
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/syserror"
)
// SyscallRestartBlock represents the restart block for a syscall restartable
@@ -383,8 +382,6 @@ func ExtractErrno(err error, sysno int) int {
return int(err)
case *errors.Error:
return int(err.Errno())
- case syserror.SyscallRestartErrno:
- return int(err)
case *memmap.BusError:
// Bus errors may generate SIGBUS, but for syscalls they still
// return EFAULT. See case in task_run.go where the fault is
@@ -397,8 +394,8 @@ func ExtractErrno(err error, sysno int) int {
case *os.SyscallError:
return ExtractErrno(err.Err, sysno)
default:
- if errno, ok := syserror.TranslateError(err); ok {
- return int(errno)
+ if errno, ok := linuxerr.TranslateError(err); ok {
+ return int(errno.Errno())
}
}
panic(fmt.Sprintf("Unknown syscall %d error: %v", sysno, err))
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
index 8e2c36598..bff226a11 100644
--- a/pkg/sentry/kernel/task_usermem.go
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -105,7 +104,7 @@ func (t *Task) CopyInVector(addr hostarch.Addr, maxElemSize, maxTotalSize int) (
// Each string has a zero terminating byte counted, so copying out a string
// requires at least one byte of space. Also, see the calculation below.
if maxTotalSize <= 0 {
- return nil, syserror.ENOMEM
+ return nil, linuxerr.ENOMEM
}
thisMax := maxElemSize
if maxTotalSize < thisMax {
@@ -148,7 +147,7 @@ func (t *Task) CopyOutIovecs(addr hostarch.Addr, src hostarch.AddrRangeSeq) erro
}
default:
- return syserror.ENOSYS
+ return linuxerr.ENOSYS
}
return nil
@@ -220,7 +219,7 @@ func (t *Task) CopyInIovecs(addr hostarch.Addr, numIovecs int) (hostarch.AddrRan
}
default:
- return hostarch.AddrRangeSeq{}, syserror.ENOSYS
+ return hostarch.AddrRangeSeq{}, linuxerr.ENOSYS
}
// Truncate to MAX_RW_COUNT.
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 2eda15303..5814a4eca 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -489,11 +489,6 @@ func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID)
tg.signalHandlers.mu.Lock()
defer tg.signalHandlers.mu.Unlock()
- // TODO(gvisor.dev/issue/6148): "If tcsetpgrp() is called by a member of a
- // background process group in its session, and the calling process is not
- // blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all members of
- // this background process group."
-
// tty must be the controlling terminal.
if tg.tty != tty {
return -1, linuxerr.ENOTTY
@@ -516,6 +511,16 @@ func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID)
return -1, linuxerr.EPERM
}
+ signalAction := tg.signalHandlers.actions[linux.SIGTTOU]
+ // "If tcsetpgrp() is called by a member of a background process group
+ // in its session, and the calling process is not blocking or ignoring
+ // SIGTTOU, a SIGTTOU signal is sent to all members of this background
+ // process group." We therefore check whether the caller is ignoring or
+ // blocking SIGTTOU before sending it.
+ ignored := signalAction.Handler == linux.SIG_IGN
+ blocked := tg.leader.signalMask&linux.SignalSetOf(linux.SIGTTOU) != 0
+ if tg.processGroup.id != tg.processGroup.session.foreground.id && !ignored && !blocked {
+ tg.leader.sendSignalLocked(SignalInfoPriv(linux.SIGTTOU), true)
+ }
+
tg.processGroup.session.foreground.id = pgid
return 0, nil
}
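The blocked check above is a signal-set membership test, i.e. a bitmask operation. A runnable illustration of the mask layout (Linux stores signal N at bit N-1; the signal numbers are x86 values):
package main

import "fmt"

const SIGTTOU = 22

func signalSetOf(sig uint) uint64 { return 1 << (sig - 1) }

func main() {
	mask := signalSetOf(SIGTTOU) | signalSetOf(2) // SIGTTOU and SIGINT blocked
	blocked := mask&signalSetOf(SIGTTOU) != 0
	fmt.Println(blocked) // true even though other signals are blocked too
}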
diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD
index 54bfed644..560a0f33c 100644
--- a/pkg/sentry/loader/BUILD
+++ b/pkg/sentry/loader/BUILD
@@ -37,7 +37,6 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index 577374fa4..fb213d109 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -32,7 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -116,7 +115,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
log.Infof("Error reading ELF ident: %v", err)
// The entire ident array always exists.
if err == io.EOF || err == io.ErrUnexpectedEOF {
- err = syserror.ENOEXEC
+ err = linuxerr.ENOEXEC
}
return elfInfo{}, err
}
@@ -124,22 +123,22 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
// Only some callers pre-check the ELF magic.
if !bytes.Equal(ident[:len(elfMagic)], []byte(elfMagic)) {
log.Infof("File is not an ELF")
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
// We only support 64-bit, little endian binaries
if class := elf.Class(ident[elf.EI_CLASS]); class != elf.ELFCLASS64 {
log.Infof("Unsupported ELF class: %v", class)
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
if endian := elf.Data(ident[elf.EI_DATA]); endian != elf.ELFDATA2LSB {
log.Infof("Unsupported ELF endianness: %v", endian)
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
if version := elf.Version(ident[elf.EI_VERSION]); version != elf.EV_CURRENT {
log.Infof("Unsupported ELF version: %v", version)
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
// EI_OSABI is ignored by Linux, which is the only OS supported.
os := abi.Linux
@@ -151,7 +150,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
log.Infof("Error reading ELF header: %v", err)
// The entire header always exists.
if err == io.EOF || err == io.ErrUnexpectedEOF {
- err = syserror.ENOEXEC
+ err = linuxerr.ENOEXEC
}
return elfInfo{}, err
}
@@ -166,7 +165,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
a = arch.ARM64
default:
log.Infof("Unsupported ELF machine %d", machine)
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
var sharedObject bool
@@ -178,25 +177,25 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
sharedObject = true
default:
log.Infof("Unsupported ELF type %v", elfType)
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
if int(hdr.Phentsize) != prog64Size {
log.Infof("Unsupported phdr size %d", hdr.Phentsize)
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
totalPhdrSize := prog64Size * int(hdr.Phnum)
if totalPhdrSize < prog64Size {
log.Warningf("No phdrs or total phdr size overflows: prog64Size: %d phnum: %d", prog64Size, int(hdr.Phnum))
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
if totalPhdrSize > maxTotalPhdrSize {
log.Infof("Too many phdrs (%d): total size %d > %d", hdr.Phnum, totalPhdrSize, maxTotalPhdrSize)
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
if int64(hdr.Phoff) < 0 || int64(hdr.Phoff+uint64(totalPhdrSize)) < 0 {
ctx.Infof("Unsupported phdr offset %d", hdr.Phoff)
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
phdrBuf := make([]byte, totalPhdrSize)
@@ -205,7 +204,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
log.Infof("Error reading ELF phdrs: %v", err)
// If phdrs were specified, they should all exist.
if err == io.EOF || err == io.ErrUnexpectedEOF {
- err = syserror.ENOEXEC
+ err = linuxerr.ENOEXEC
}
return elfInfo{}, err
}
@@ -248,19 +247,19 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr
if !ok {
// If offset != 0 we should have ensured this would fit.
ctx.Warningf("Computed segment load address overflows: %#x + %#x", phdr.Vaddr, offset)
- return syserror.ENOEXEC
+ return linuxerr.ENOEXEC
}
addr -= hostarch.Addr(adjust)
fileSize := phdr.Filesz + adjust
if fileSize < phdr.Filesz {
ctx.Infof("Computed segment file size overflows: %#x + %#x", phdr.Filesz, adjust)
- return syserror.ENOEXEC
+ return linuxerr.ENOEXEC
}
ms, ok := hostarch.Addr(fileSize).RoundUp()
if !ok {
ctx.Infof("fileSize %#x too large", fileSize)
- return syserror.ENOEXEC
+ return linuxerr.ENOEXEC
}
mapSize := uint64(ms)
@@ -321,7 +320,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr
memSize := phdr.Memsz + adjust
if memSize < phdr.Memsz {
ctx.Infof("Computed segment mem size overflows: %#x + %#x", phdr.Memsz, adjust)
- return syserror.ENOEXEC
+ return linuxerr.ENOEXEC
}
// Allocate more anonymous pages if necessary.
@@ -333,7 +332,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr
anonSize, ok := hostarch.Addr(memSize - mapSize).RoundUp()
if !ok {
ctx.Infof("extra anon pages too large: %#x", memSize-mapSize)
- return syserror.ENOEXEC
+ return linuxerr.ENOEXEC
}
// N.B. Linux uses vm_brk_flags to map these pages, which only
@@ -423,27 +422,27 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in
// NOTE(b/37474556): Linux allows out-of-order
// segments, in violation of the spec.
ctx.Infof("PT_LOAD headers out-of-order. %#x < %#x", vaddr, end)
- return loadedELF{}, syserror.ENOEXEC
+ return loadedELF{}, linuxerr.ENOEXEC
}
var ok bool
end, ok = vaddr.AddLength(phdr.Memsz)
if !ok {
ctx.Infof("PT_LOAD header size overflows. %#x + %#x", vaddr, phdr.Memsz)
- return loadedELF{}, syserror.ENOEXEC
+ return loadedELF{}, linuxerr.ENOEXEC
}
case elf.PT_INTERP:
if phdr.Filesz < 2 {
ctx.Infof("PT_INTERP path too small: %v", phdr.Filesz)
- return loadedELF{}, syserror.ENOEXEC
+ return loadedELF{}, linuxerr.ENOEXEC
}
if phdr.Filesz > linux.PATH_MAX {
ctx.Infof("PT_INTERP path too big: %v", phdr.Filesz)
- return loadedELF{}, syserror.ENOEXEC
+ return loadedELF{}, linuxerr.ENOEXEC
}
if int64(phdr.Off) < 0 || int64(phdr.Off+phdr.Filesz) < 0 {
ctx.Infof("Unsupported PT_INTERP offset %d", phdr.Off)
- return loadedELF{}, syserror.ENOEXEC
+ return loadedELF{}, linuxerr.ENOEXEC
}
path := make([]byte, phdr.Filesz)
@@ -451,12 +450,12 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in
if err != nil {
// If an interpreter was specified, it should exist.
ctx.Infof("Error reading PT_INTERP path: %v", err)
- return loadedELF{}, syserror.ENOEXEC
+ return loadedELF{}, linuxerr.ENOEXEC
}
if path[len(path)-1] != 0 {
ctx.Infof("PT_INTERP path not NUL-terminated: %v", path)
- return loadedELF{}, syserror.ENOEXEC
+ return loadedELF{}, linuxerr.ENOEXEC
}
// Strip NUL-terminator and everything beyond from
@@ -498,7 +497,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in
totalSize, ok := totalSize.RoundUp()
if !ok {
ctx.Infof("ELF PT_LOAD segments too big")
- return loadedELF{}, syserror.ENOEXEC
+ return loadedELF{}, linuxerr.ENOEXEC
}
var err error
@@ -592,7 +591,7 @@ func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureS
// Check Image Compatibility.
if arch.Host != info.arch {
ctx.Warningf("Found mismatch for platform %s with ELF type %s", arch.Host.String(), info.arch.String())
- return loadedELF{}, nil, syserror.ENOEXEC
+ return loadedELF{}, nil, linuxerr.ENOEXEC
}
// Create the arch.Context now so we can prepare the mmap layout before
@@ -681,7 +680,7 @@ func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error
if interp.interpreter != "" {
// No recursive interpreters!
ctx.Infof("Interpreter requires an interpreter")
- return loadedELF{}, nil, syserror.ENOEXEC
+ return loadedELF{}, nil, linuxerr.ENOEXEC
}
}
diff --git a/pkg/sentry/loader/interpreter.go b/pkg/sentry/loader/interpreter.go
index 3e302d92c..1ec0d7019 100644
--- a/pkg/sentry/loader/interpreter.go
+++ b/pkg/sentry/loader/interpreter.go
@@ -19,8 +19,8 @@ import (
"io"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -43,14 +43,14 @@ func parseInterpreterScript(ctx context.Context, filename string, f fsbridge.Fil
// Short read is OK.
if err != nil && err != io.ErrUnexpectedEOF {
if err == io.EOF {
- err = syserror.ENOEXEC
+ err = linuxerr.ENOEXEC
}
return "", []string{}, err
}
line = line[:n]
if !bytes.Equal(line[:2], []byte(interpreterScriptMagic)) {
- return "", []string{}, syserror.ENOEXEC
+ return "", []string{}, linuxerr.ENOEXEC
}
// Ignore #!.
line = line[2:]
@@ -82,7 +82,7 @@ func parseInterpreterScript(ctx context.Context, filename string, f fsbridge.Fil
if string(interp) == "" {
ctx.Infof("Interpreter script contains no interpreter: %v", line)
- return "", []string{}, syserror.ENOEXEC
+ return "", []string{}, linuxerr.ENOEXEC
}
// Build the new argument list:
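To make the parsing above concrete: Linux passes everything after the interpreter path as a single optional argument. A simplified runnable sketch (real parsing also trims surrounding whitespace and caps the line length):
package main

import (
	"fmt"
	"strings"
)

func main() {
	line := "#!/usr/bin/env python3 -u"
	rest := strings.TrimSpace(strings.TrimPrefix(line, "#!"))
	interp, arg, _ := strings.Cut(rest, " ")
	fmt.Println(interp) // /usr/bin/env
	fmt.Println(arg)    // "python3 -u" is one argument, not two
	// New argv becomes: interp, arg (if non-empty), script name, old argv[1:].
}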
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index 86d0c54cd..2759ef71e 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -35,7 +35,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -91,7 +90,7 @@ type LoadArgs struct {
func openPath(ctx context.Context, args LoadArgs) (fsbridge.File, error) {
if args.Filename == "" {
ctx.Infof("cannot open empty name")
- return nil, syserror.ENOENT
+ return nil, linuxerr.ENOENT
}
// TODO(gvisor.dev/issue/160): Linux requires only execute permission,
@@ -172,7 +171,7 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context
// (e.g., #!a).
if err != nil && err != io.ErrUnexpectedEOF {
if err == io.EOF {
- err = syserror.ENOEXEC
+ err = linuxerr.ENOEXEC
}
return loadedELF{}, nil, nil, nil, err
}
@@ -190,7 +189,7 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context
case bytes.Equal(hdr[:2], []byte(interpreterScriptMagic)):
if args.CloseOnExec {
- return loadedELF{}, nil, nil, nil, syserror.ENOENT
+ return loadedELF{}, nil, nil, nil, linuxerr.ENOENT
}
args.Filename, args.Argv, err = parseInterpreterScript(ctx, args.Filename, args.File, args.Argv)
if err != nil {
@@ -202,7 +201,7 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context
default:
ctx.Infof("Unknown magic: %v", hdr)
- return loadedELF{}, nil, nil, nil, syserror.ENOEXEC
+ return loadedELF{}, nil, nil, nil, linuxerr.ENOEXEC
}
// Set to nil in case we loop on a Interpreter Script.
args.File = nil
@@ -296,15 +295,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
m.SetEnvvEnd(sl.EnvvEnd)
m.SetAuxv(auxv)
m.SetExecutable(ctx, file)
-
- symbolValue, err := getSymbolValueFromVDSO("rt_sigreturn")
- if err != nil {
- return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to find rt_sigreturn in vdso: %v", err), syserr.FromError(err).ToLinux())
- }
-
- // Found rt_sigretrun.
- addr := uint64(vdsoAddr) + symbolValue - vdsoPrelink
- m.SetVDSOSigReturn(addr)
+ m.SetVDSOSigReturn(uint64(vdsoAddr) + vdsoSigreturnOffset - vdsoPrelink)
ac.SetIP(uintptr(loaded.entry))
ac.SetStack(uintptr(stack.Bottom))
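The patch replaces per-exec symbol parsing with a precomputed vdsoSigreturnOffset. A hedged sketch of computing such an offset once at startup, reusing the ELF lookup that the removed getSymbolValueFromVDSO (in vdso.go below) performed — the symbol name is arch-dependent and the real constant may be produced at build time:
var vdsoSigreturnOffset = func() uint64 {
	f, err := elf.NewFile(bytes.NewReader(vdsodata.Binary))
	if err != nil {
		panic(err)
	}
	syms, err := f.Symbols()
	if err != nil {
		panic(err)
	}
	for _, s := range syms {
		if s.Name == "__kernel_rt_sigreturn" { // assumed name; varies by arch
			return s.Value
		}
	}
	panic("rt_sigreturn not found in vdso")
}()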
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index 054ef1723..bcee6aef6 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -19,7 +19,6 @@ import (
"debug/elf"
"fmt"
"io"
- "strings"
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/context"
@@ -34,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -102,14 +100,14 @@ func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, erro
first = &info.phdrs[i]
if phdr.Off != 0 {
log.Warningf("First PT_LOAD segment has non-zero file offset")
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
}
memoryOffset := phdr.Vaddr - first.Vaddr
if memoryOffset != phdr.Off {
log.Warningf("PT_LOAD segment memory offset %#x != file offset %#x", memoryOffset, phdr.Off)
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
// memsz larger than filesz means that extra zeroed space should be
@@ -118,24 +116,24 @@ func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, erro
// zeroes.
if phdr.Memsz != phdr.Filesz {
log.Warningf("PT_LOAD segment memsz %#x != filesz %#x", phdr.Memsz, phdr.Filesz)
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
start := hostarch.Addr(memoryOffset)
end, ok := start.AddLength(phdr.Memsz)
if !ok {
log.Warningf("PT_LOAD segment size overflows: %#x + %#x", start, end)
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
if uint64(end) > size {
log.Warningf("PT_LOAD segment end %#x extends beyond end of file %#x", end, size)
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
if prev != nil {
if start < prevEnd {
log.Warningf("PT_LOAD segments out of order")
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
// We mprotect entire pages, so each segment must be in
@@ -144,7 +142,7 @@ func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, erro
startPage := start.RoundDown()
if prevEndPage >= startPage {
log.Warningf("PT_LOAD segments share a page: %#x", prevEndPage)
- return elfInfo{}, syserror.ENOEXEC
+ return elfInfo{}, linuxerr.ENOEXEC
}
}
prev = &info.phdrs[i]
@@ -178,27 +176,6 @@ type VDSO struct {
phdrs []elf.ProgHeader `state:".([]elfProgHeader)"`
}
-// getSymbolValueFromVDSO returns the specific symbol value in vdso.so.
-func getSymbolValueFromVDSO(symbol string) (uint64, error) {
- f, err := elf.NewFile(bytes.NewReader(vdsodata.Binary))
- if err != nil {
- return 0, err
- }
- syms, err := f.Symbols()
- if err != nil {
- return 0, err
- }
-
- for _, sym := range syms {
- if elf.ST_BIND(sym.Info) != elf.STB_LOCAL && sym.Section != elf.SHN_UNDEF {
- if strings.Contains(sym.Name, symbol) {
- return sym.Value, nil
- }
- }
- }
- return 0, fmt.Errorf("no %v in vdso.so", symbol)
-}
-
// PrepareVDSO validates the system VDSO and returns a VDSO, containing the
// param page for updating by the kernel.
func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
@@ -271,11 +248,11 @@ func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF) (hostarch.Addr, error) {
if v.os != bin.os {
ctx.Warningf("Binary ELF OS %v and VDSO ELF OS %v differ", bin.os, v.os)
- return 0, syserror.ENOEXEC
+ return 0, linuxerr.ENOEXEC
}
if v.arch != bin.arch {
ctx.Warningf("Binary ELF arch %v and VDSO ELF arch %v differ", bin.arch, v.arch)
- return 0, syserror.ENOEXEC
+ return 0, linuxerr.ENOEXEC
}
// Reserve address space for the VDSO and its parameter page, which is
@@ -348,35 +325,35 @@ func loadVDSO(ctx context.Context, m *mm.MemoryManager, v *VDSO, bin loadedELF)
segAddr, ok := vdsoAddr.AddLength(memoryOffset)
if !ok {
ctx.Warningf("PT_LOAD segment address overflows: %#x + %#x", segAddr, memoryOffset)
- return 0, syserror.ENOEXEC
+ return 0, linuxerr.ENOEXEC
}
segPage := segAddr.RoundDown()
segSize := hostarch.Addr(phdr.Memsz)
segSize, ok = segSize.AddLength(segAddr.PageOffset())
if !ok {
ctx.Warningf("PT_LOAD segment memsize %#x + offset %#x overflows", phdr.Memsz, segAddr.PageOffset())
- return 0, syserror.ENOEXEC
+ return 0, linuxerr.ENOEXEC
}
segSize, ok = segSize.RoundUp()
if !ok {
ctx.Warningf("PT_LOAD segment size overflows: %#x", phdr.Memsz+segAddr.PageOffset())
- return 0, syserror.ENOEXEC
+ return 0, linuxerr.ENOEXEC
}
segEnd, ok := segPage.AddLength(uint64(segSize))
if !ok {
ctx.Warningf("PT_LOAD segment range overflows: %#x + %#x", segAddr, segSize)
- return 0, syserror.ENOEXEC
+ return 0, linuxerr.ENOEXEC
}
if segEnd > vdsoEnd {
ctx.Warningf("PT_LOAD segment ends beyond VDSO: %#x > %#x", segEnd, vdsoEnd)
- return 0, syserror.ENOEXEC
+ return 0, linuxerr.ENOEXEC
}
perms := progFlagsAsPerms(phdr.Flags)
if perms != hostarch.Read {
if err := m.MProtect(segPage, uint64(segSize), perms, false); err != nil {
ctx.Warningf("Unable to set PT_LOAD segment protections %+v at [%#x, %#x): %v", perms, segAddr, segEnd, err)
- return 0, syserror.ENOEXEC
+ return 0, linuxerr.ENOEXEC
}
}
}
@@ -389,3 +366,21 @@ func (v *VDSO) Release(ctx context.Context) {
v.ParamPage.DecRef(ctx)
v.vdso.DecRef(ctx)
}
+
+var vdsoSigreturnOffset = func() uint64 {
+ f, err := elf.NewFile(bytes.NewReader(vdsodata.Binary))
+ if err != nil {
+ panic(fmt.Sprintf("failed to parse vdso.so as ELF file: %v", err))
+ }
+ syms, err := f.Symbols()
+ if err != nil {
+ panic(fmt.Sprintf("failed to read symbols from vdso.so: %v", err))
+ }
+ const sigreturnSymbol = "__kernel_rt_sigreturn"
+ for _, sym := range syms {
+ if elf.ST_BIND(sym.Info) != elf.STB_LOCAL && sym.Section != elf.SHN_UNDEF && sym.Name == sigreturnSymbol {
+ return sym.Value
+ }
+ }
+ panic(fmt.Sprintf("no symbol %q in vdso.so", sigreturnSymbol))
+}()
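
These vdso.go changes (together with the loader.go hunk above) replace a per-execve symbol lookup, which had a recoverable error path, with a package-level initializer that runs once and panics on failure: a missing __kernel_rt_sigreturn can only mean a broken embedded vdso, not a runtime condition. The match is also tightened from strings.Contains to an exact symbol name. A minimal standalone sketch of the pattern, with vdsoBytes standing in for vdsodata.Binary:

```go
package loaderexample

import (
	"bytes"
	"debug/elf"
	"fmt"
)

// vdsoBytes stands in for the embedded vdso.so image (vdsodata.Binary in the
// real code).
var vdsoBytes []byte

// symbolOffset returns the value of a non-local, defined symbol in an ELF
// image, matching the symbol name exactly.
func symbolOffset(image []byte, name string) (uint64, error) {
	f, err := elf.NewFile(bytes.NewReader(image))
	if err != nil {
		return 0, err
	}
	syms, err := f.Symbols()
	if err != nil {
		return 0, err
	}
	for _, sym := range syms {
		if elf.ST_BIND(sym.Info) != elf.STB_LOCAL && sym.Section != elf.SHN_UNDEF && sym.Name == name {
			return sym.Value, nil
		}
	}
	return 0, fmt.Errorf("no symbol %q in image", name)
}

// Computed once at init; failure is a build defect rather than a runtime
// error, so panicking here matches the diff above.
var sigreturnOffset = func() uint64 {
	off, err := symbolOffset(vdsoBytes, "__kernel_rt_sigreturn")
	if err != nil {
		panic(err)
	}
	return off
}()
```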
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
index c30e88725..a89bfa680 100644
--- a/pkg/sentry/memmap/BUILD
+++ b/pkg/sentry/memmap/BUILD
@@ -54,7 +54,6 @@ go_library(
"//pkg/hostarch",
"//pkg/log",
"//pkg/safemem",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index 69aff21b6..b7d782b7f 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -144,7 +144,6 @@ go_library(
"//pkg/sentry/platform",
"//pkg/sentry/usage",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/tcpip/buffer",
"//pkg/usermem",
],
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index b7f765cd7..d71d64580 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -77,15 +77,6 @@ func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64)
return nil
}
- // Only unmaps after it assured that the address is a valid aio context to
- // prevent random memory from been unmapped.
- //
- // Note: It's possible to unmap this address and map something else into
- // the same address. Then it would be unmapping memory that it doesn't own.
- // This is, however, the way Linux implements AIO. Keeps the same [weird]
- // semantics in case anyone relies on it.
- mm.MUnmap(ctx, hostarch.Addr(id), aioRingBufferSize)
-
delete(mm.aioManager.contexts, id)
aioCtx.destroy()
return aioCtx
@@ -411,6 +402,15 @@ func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOC
return nil
}
+	// Only unmap after it is assured that the address is a valid AIO
+	// context, to prevent random memory from being unmapped.
+	//
+	// Note: It's possible to unmap this address and map something else into
+	// the same address. Then it would be unmapping memory that it doesn't own.
+	// This is, however, the way Linux implements AIO. Keep the same [weird]
+	// semantics in case anyone relies on it.
+ mm.MUnmap(ctx, hostarch.Addr(id), aioRingBufferSize)
+
mm.aioManager.mu.Lock()
defer mm.aioManager.mu.Unlock()
return mm.destroyAIOContextLocked(ctx, id)
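
This move appears to be about lock ordering rather than behavior: MUnmap acquires mm.mappingMu, which per the lock order documented at the top of mm.go (visible in the next hunk's context) must be taken before mm.aioManager.mu, so the unmap now happens in DestroyAIOContext before the manager lock is acquired instead of inside the locked helper. A toy sketch of the rule, with illustrative names:

```go
package mmexample

import "sync"

type aioManager struct {
	mu       sync.Mutex
	contexts map[uint64]struct{}
}

type memoryManager struct {
	mappingMu sync.Mutex // ordered before aio.mu
	aio       aioManager
}

func (mm *memoryManager) munmap(addr uint64) {
	mm.mappingMu.Lock()
	defer mm.mappingMu.Unlock()
	// ... drop the mapping ...
}

func (mm *memoryManager) destroyAIOContext(addr uint64) {
	// Unmap before taking aio.mu; nesting munmap inside the manager lock
	// would invert the documented order (mappingMu -> aioManager.mu).
	mm.munmap(addr)
	mm.aio.mu.Lock()
	defer mm.aio.mu.Unlock()
	delete(mm.aio.contexts, addr)
}
```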
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 57969b26c..0fca59b64 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -28,6 +28,7 @@
// memmap.File locks
// mm.aioManager.mu
// mm.AIOContext.mu
+// kernel.TaskSet.mu
//
// Only mm.MemoryManager.Fork is permitted to lock mm.MemoryManager.activeMu in
// multiple mm.MemoryManagers, as it does so in a well-defined order (forked
diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go
index 9f4cc238f..05cdcd8ae 100644
--- a/pkg/sentry/mm/pma.go
+++ b/pkg/sentry/mm/pma.go
@@ -324,20 +324,37 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter
panic(fmt.Sprintf("pma %v needs to be copied for writing, but is not readable: %v", pseg.Range(), oldpma))
}
}
- // The majority of copy-on-write breaks on executable pages
- // come from:
- //
- // - The ELF loader, which must zero out bytes on the last
- // page of each segment after the end of the segment.
- //
- // - gdb's use of ptrace to insert breakpoints.
- //
- // Neither of these cases has enough spatial locality to
- // benefit from copying nearby pages, so if the vma is
- // executable, only copy the pages required.
var copyAR hostarch.AddrRange
- if vseg.ValuePtr().effectivePerms.Execute {
+ if vma := vseg.ValuePtr(); vma.effectivePerms.Execute {
+ // The majority of copy-on-write breaks on executable
+ // pages come from:
+ //
+ // - The ELF loader, which must zero out bytes on the
+ // last page of each segment after the end of the
+ // segment.
+ //
+ // - gdb's use of ptrace to insert breakpoints.
+ //
+ // Neither of these cases has enough spatial locality
+ // to benefit from copying nearby pages, so if the vma
+ // is executable, only copy the pages required.
copyAR = pseg.Range().Intersect(ar)
+ } else if vma.growsDown {
+ // In most cases, the new process will not use most of
+ // its stack before exiting or invoking execve(); it is
+ // especially unlikely to return very far down its call
+ // stack, since async-signal-safety concerns in
+ // multithreaded programs prevent the new process from
+ // being able to do much. So only copy up to one page
+ // before and after the pages required.
+ stackMaskAR := ar
+ if newStart := stackMaskAR.Start - hostarch.PageSize; newStart < stackMaskAR.Start {
+ stackMaskAR.Start = newStart
+ }
+ if newEnd := stackMaskAR.End + hostarch.PageSize; newEnd > stackMaskAR.End {
+ stackMaskAR.End = newEnd
+ }
+ copyAR = pseg.Range().Intersect(stackMaskAR)
} else {
copyAR = pseg.Range().Intersect(maskAR)
}
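
The new growsDown branch widens the faulting range by at most one page on each side before intersecting it with the pma range; the odd-looking comparisons double as wraparound guards, since an underflowed Start or overflowed End would compare the wrong way and the expansion on that side is then skipped. A condensed sketch:

```go
package pmaexample

const pageSize = 4096

type addrRange struct{ Start, End uint64 }

// expandForStack widens [Start, End) by one page in each direction, skipping
// either side if the unsigned arithmetic would wrap.
func expandForStack(ar addrRange) addrRange {
	if newStart := ar.Start - pageSize; newStart < ar.Start {
		ar.Start = newStart // no underflow
	}
	if newEnd := ar.End + pageSize; newEnd > ar.End {
		ar.End = newEnd // no overflow
	}
	return ar
}
```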
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index 256eb4afb..dc12ad357 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -27,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/syserror"
)
// HandleUserFault handles an application page fault. sp is the faulting
@@ -79,7 +78,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostar
}
length, ok := hostarch.Addr(opts.Length).RoundUp()
if !ok {
- return 0, syserror.ENOMEM
+ return 0, linuxerr.ENOMEM
}
opts.Length = uint64(length)
@@ -90,7 +89,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostar
}
// Offset + length must not overflow.
if end := opts.Offset + opts.Length; end < opts.Offset {
- return 0, syserror.ENOMEM
+ return 0, linuxerr.EOVERFLOW
}
} else {
opts.Offset = 0
@@ -253,7 +252,7 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (hostarch.AddrRange, erro
ctx.Warningf("Capping stack size from RLIMIT_STACK of %v down to %v.", sz, maxStackSize)
sz = maxStackSize
} else if sz == 0 {
- return hostarch.AddrRange{}, syserror.ENOMEM
+ return hostarch.AddrRange{}, linuxerr.ENOMEM
}
szaddr := hostarch.Addr(sz)
ctx.Debugf("Allocating stack with size of %v bytes", sz)
@@ -262,7 +261,7 @@ func (mm *MemoryManager) MapStack(ctx context.Context) (hostarch.AddrRange, erro
// randomization can't be disabled.
stackEnd := mm.layout.MaxAddr - hostarch.Addr(mrand.Int63n(int64(mm.layout.MaxStackRand))).RoundDown()
if stackEnd < szaddr {
- return hostarch.AddrRange{}, syserror.ENOMEM
+ return hostarch.AddrRange{}, linuxerr.ENOMEM
}
stackStart := stackEnd - szaddr
mm.mappingMu.Lock()
@@ -500,7 +499,7 @@ func (mm *MemoryManager) MRemap(ctx context.Context, oldAddr hostarch.Addr, oldS
// Check against RLIMIT_AS.
newUsageAS := mm.usageAS - uint64(oldAR.Length()) + uint64(newAR.Length())
if limitAS := limits.FromContext(ctx).Get(limits.AS).Cur; newUsageAS > limitAS {
- return 0, syserror.ENOMEM
+ return 0, linuxerr.ENOMEM
}
if vma := vseg.ValuePtr(); vma.mappable != nil {
@@ -599,11 +598,11 @@ func (mm *MemoryManager) MProtect(addr hostarch.Addr, length uint64, realPerms h
}
rlength, ok := hostarch.Addr(length).RoundUp()
if !ok {
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
ar, ok := addr.ToRange(uint64(rlength))
if !ok {
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
effectivePerms := realPerms.Effective()
@@ -616,19 +615,19 @@ func (mm *MemoryManager) MProtect(addr hostarch.Addr, length uint64, realPerms h
// the non-growsDown case.
vseg := mm.vmas.LowerBoundSegment(ar.Start)
if !vseg.Ok() {
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
if growsDown {
if !vseg.ValuePtr().growsDown {
return linuxerr.EINVAL
}
if ar.End <= vseg.Start() {
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
ar.Start = vseg.Start()
} else {
if ar.Start < vseg.Start() {
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
}
@@ -688,7 +687,7 @@ func (mm *MemoryManager) MProtect(addr hostarch.Addr, length uint64, realPerms h
}
vseg, _ = vseg.NextNonEmpty()
if !vseg.Ok() {
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
}
}
@@ -724,7 +723,7 @@ func (mm *MemoryManager) Brk(ctx context.Context, addr hostarch.Addr) (hostarch.
if uint64(addr-mm.brk.Start) > limits.FromContext(ctx).Get(limits.Data).Cur {
addr = mm.brk.End
mm.mappingMu.Unlock()
- return addr, syserror.ENOMEM
+ return addr, linuxerr.ENOMEM
}
oldbrkpg, _ := mm.brk.End.RoundUp()
@@ -798,7 +797,7 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u
}
if newLockedAS := mm.lockedAS + uint64(ar.Length()) - mm.mlockedBytesRangeLocked(ar); newLockedAS > mlockLimit {
mm.mappingMu.Unlock()
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
}
}
@@ -835,7 +834,7 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u
mm.vmas.MergeAdjacent(ar)
if unmapped {
mm.mappingMu.Unlock()
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
if mode == memmap.MLockEager {
@@ -850,7 +849,7 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u
// case, which is converted to ENOMEM by mlock.
mm.activeMu.Unlock()
mm.mappingMu.RUnlock()
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
_, _, err := mm.getPMAsLocked(ctx, vseg, vseg.Range().Intersect(ar), hostarch.NoAccess)
if err != nil {
@@ -858,7 +857,7 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u
mm.mappingMu.RUnlock()
// Linux: mm/mlock.c:__mlock_posix_error_return()
if linuxerr.Equals(linuxerr.EFAULT, err) {
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
if linuxerr.Equals(linuxerr.ENOMEM, err) {
return linuxerr.EAGAIN
@@ -917,7 +916,7 @@ func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error
}
if uint64(mm.vmas.Span()) > mlockLimit {
mm.mappingMu.Unlock()
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
}
}
@@ -1040,7 +1039,7 @@ func (mm *MemoryManager) SetDontFork(addr hostarch.Addr, length uint64, dontfork
}
if mm.vmas.SpanRange(ar) != ar.Length() {
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
return nil
}
@@ -1099,7 +1098,7 @@ func (mm *MemoryManager) Decommit(addr hostarch.Addr, length uint64) error {
// to the rest (but returns ENOMEM from the system call, as it should)." -
// madvise(2)
if mm.vmas.SpanRange(ar) != ar.Length() {
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
return nil
}
@@ -1123,11 +1122,11 @@ func (mm *MemoryManager) MSync(ctx context.Context, addr hostarch.Addr, length u
}
la, ok := hostarch.Addr(length).RoundUp()
if !ok {
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
ar, ok := addr.ToRange(uint64(la))
if !ok {
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
mm.mappingMu.RLock()
@@ -1135,7 +1134,7 @@ func (mm *MemoryManager) MSync(ctx context.Context, addr hostarch.Addr, length u
vseg := mm.vmas.LowerBoundSegment(ar.Start)
if !vseg.Ok() {
mm.mappingMu.RUnlock()
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
var unmapped bool
lastEnd := ar.Start
@@ -1184,7 +1183,7 @@ func (mm *MemoryManager) MSync(ctx context.Context, addr hostarch.Addr, length u
}
if unmapped {
- return syserror.ENOMEM
+ return linuxerr.ENOMEM
}
return nil
}
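
The MLock error translation above follows Linux's mm/mlock.c:__mlock_posix_error_return(): a fault encountered while populating locked pages surfaces as ENOMEM, while memory exhaustion surfaces as EAGAIN. As a standalone helper it would look like:

```go
package mmexample

import "gvisor.dev/gvisor/pkg/errors/linuxerr"

// translateMlockError mirrors mm/mlock.c:__mlock_posix_error_return().
func translateMlockError(err error) error {
	switch {
	case linuxerr.Equals(linuxerr.EFAULT, err):
		return linuxerr.ENOMEM
	case linuxerr.Equals(linuxerr.ENOMEM, err):
		return linuxerr.EAGAIN
	}
	return err
}
```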
diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go
index 5f8ab7ca3..e34b7a2f7 100644
--- a/pkg/sentry/mm/vma.go
+++ b/pkg/sentry/mm/vma.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Preconditions:
@@ -59,7 +58,7 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp
newUsageAS -= uint64(mm.vmas.SpanRange(ar))
}
if limitAS := limits.FromContext(ctx).Get(limits.AS).Cur; newUsageAS > limitAS {
- return vmaIterator{}, hostarch.AddrRange{}, syserror.ENOMEM
+ return vmaIterator{}, hostarch.AddrRange{}, linuxerr.ENOMEM
}
if opts.MLockMode != memmap.MLockNone {
@@ -178,7 +177,7 @@ func (mm *MemoryManager) findAvailableLocked(length uint64, opts findAvailableOp
// Fixed mappings accept only the requested address.
if opts.Fixed {
- return 0, syserror.ENOMEM
+ return 0, linuxerr.ENOMEM
}
// Prefer hugepage alignment if a hugepage or more is requested.
@@ -216,7 +215,7 @@ func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bou
return gr.Start, nil
}
}
- return 0, syserror.ENOMEM
+ return 0, linuxerr.ENOMEM
}
// Preconditions: mm.mappingMu must be locked.
@@ -236,7 +235,7 @@ func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bo
return start, nil
}
}
- return 0, syserror.ENOMEM
+ return 0, linuxerr.ENOMEM
}
// Preconditions: mm.mappingMu must be locked.
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index d351869ef..496a9fd97 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -97,7 +97,6 @@ go_library(
"//pkg/state",
"//pkg/state/wire",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index 0c8542485..68e17d343 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -39,7 +39,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// MemoryFile is a memmap.File whose pages may be allocated to arbitrary
@@ -404,7 +403,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.File
// Find a range in the underlying file.
fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment)
if !ok {
- return memmap.FileRange{}, syserror.ENOMEM
+ return memmap.FileRange{}, linuxerr.ENOMEM
}
// Expand the file if needed.
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 8a490b3de..834d72408 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -1,13 +1,26 @@
load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
+go_template_instance(
+ name = "atomicptr_machine",
+ out = "atomicptr_machine_unsafe.go",
+ package = "kvm",
+ prefix = "machine",
+ template = "//pkg/sync/atomicptr:generic_atomicptr",
+ types = {
+ "Value": "machine",
+ },
+)
+
go_library(
name = "kvm",
srcs = [
"address_space.go",
"address_space_amd64.go",
"address_space_arm64.go",
+ "atomicptr_machine_unsafe.go",
"bluepill.go",
"bluepill_allocator.go",
"bluepill_amd64.go",
@@ -50,7 +63,6 @@ go_library(
"//pkg/procid",
"//pkg/ring0",
"//pkg/ring0/pagetables",
- "//pkg/safecopy",
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/arch/fpu",
@@ -58,6 +70,7 @@ go_library(
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
"//pkg/sentry/time",
+ "//pkg/sighandling",
"//pkg/sync",
"@org_golang_x_sys//unix:go_default_library",
],
@@ -69,10 +82,17 @@ go_test(
"kvm_amd64_test.go",
"kvm_amd64_test.s",
"kvm_arm64_test.go",
+ "kvm_safecopy_test.go",
"kvm_test.go",
"virtual_map_test.go",
],
library = ":kvm",
+ # FIXME(gvisor.dev/issue/3374): Not working with all build systems.
+ nogo = False,
+ # cgo has to be disabled. We have seen libc that blocks all signals and
+ # calls mmap from pthread_create, but we use SIGSYS to trap mmap system
+ # calls.
+ pure = True,
tags = [
"manual",
"nogotsan",
@@ -81,8 +101,10 @@ go_test(
deps = [
"//pkg/abi/linux",
"//pkg/hostarch",
+ "//pkg/memutil",
"//pkg/ring0",
"//pkg/ring0/pagetables",
+ "//pkg/safecopy",
"//pkg/sentry/arch",
"//pkg/sentry/arch/fpu",
"//pkg/sentry/platform",
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
index bb9967b9f..5be2215ed 100644
--- a/pkg/sentry/platform/kvm/bluepill.go
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -19,8 +19,8 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/ring0"
- "gvisor.dev/gvisor/pkg/safecopy"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sighandling"
)
// bluepill enters guest mode.
@@ -61,6 +61,9 @@ var (
// This is called by bluepillHandler.
savedHandler uintptr
+	// savedSigsysHandler is a pointer to the previous handler of the SIGSYS signal.
+ savedSigsysHandler uintptr
+
// dieTrampolineAddr is the address of dieTrampoline.
dieTrampolineAddr uintptr
)
@@ -94,7 +97,7 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) {
func init() {
// Install the handler.
- if err := safecopy.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil {
+ if err := sighandling.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go
index 0567c8d32..b2db2bb9f 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.go
@@ -71,10 +71,6 @@ func (c *vCPU) KernelSyscall() {
if regs.Rax != ^uint64(0) {
regs.Rip -= 2 // Rewind.
}
- // We only trigger a bluepill entry in the bluepill function, and can
- // therefore be guaranteed that there is no floating point state to be
- // loaded on resuming from halt. We only worry about saving on exit.
- ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no.
// N.B. Since KernelSyscall is called when the kernel makes a syscall,
// FS_BASE is already set for correct execution of this function.
//
@@ -112,8 +108,6 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
regs.Rip = 0
}
// See above.
- ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no.
- // See above.
ring0.HaltAndWriteFSBase(regs) // escapes: no, reload host segment.
}
@@ -144,5 +138,5 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
// Set the context pointer to the saved floating point state. This is
// where the guest data has been serialized, the kernel will restore
// from this new pointer value.
- context.Fpstate = uint64(uintptrValue(c.floatingPointState.BytePointer()))
+ context.Fpstate = uint64(uintptrValue(c.FloatingPointState().BytePointer())) // escapes: no.
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s
index c2a1dca11..5d8358f64 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.s
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.s
@@ -32,6 +32,8 @@
// This is checked as the source of the fault.
#define CLI $0xfa
+#define SYS_MMAP 9
+
// See bluepill.go.
TEXT ·bluepill(SB),NOSPLIT,$0
begin:
@@ -95,6 +97,31 @@ TEXT ·addrOfSighandler(SB), $0-8
MOVQ AX, ret+0(FP)
RET
+TEXT ·sigsysHandler(SB),NOSPLIT,$0
+	// si_code (at offset 0x8 in siginfo) should be SYS_SECCOMP (== 1).
+ MOVQ $1, CX
+ CMPL CX, 0x8(SI)
+ JNE fallback
+
+ MOVL CONTEXT_RAX(DX), CX
+ CMPL CX, $SYS_MMAP
+ JNE fallback
+ PUSHQ DX // First argument (context).
+ CALL ·seccompMmapHandler(SB) // Call the handler.
+ POPQ DX // Discard the argument.
+ RET
+fallback:
+ // Jump to the previous signal handler.
+ XORQ CX, CX
+ MOVQ ·savedSigsysHandler(SB), AX
+ JMP AX
+
+// func addrOfSigsysHandler() uintptr
+TEXT ·addrOfSigsysHandler(SB), $0-8
+ MOVQ $·sigsysHandler(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
+
// dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation.
TEXT ·dieTrampoline(SB),NOSPLIT,$0
PUSHQ BX // First argument (vCPU).
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index acb0cb05f..df772d620 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -70,7 +70,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
lazyVfp := c.GetLazyVFP()
if lazyVfp != 0 {
- fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
+ fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no
context.Fpsimd64.Fpsr = fpsimd.Fpsr
context.Fpsimd64.Fpcr = fpsimd.Fpcr
context.Fpsimd64.Vregs = fpsimd.Vregs
@@ -90,12 +90,12 @@ func (c *vCPU) KernelSyscall() {
fpDisableTrap := ring0.CPACREL1()
if fpDisableTrap != 0 {
- fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
+ fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no
fpcr := ring0.GetFPCR()
fpsr := ring0.GetFPSR()
fpsimd.Fpcr = uint32(fpcr)
fpsimd.Fpsr = uint32(fpsr)
- ring0.SaveVRegs(c.floatingPointState.BytePointer())
+ ring0.SaveVRegs(c.FloatingPointState().BytePointer()) // escapes: no
}
ring0.Halt()
@@ -114,12 +114,12 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
fpDisableTrap := ring0.CPACREL1()
if fpDisableTrap != 0 {
- fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
+ fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no
fpcr := ring0.GetFPCR()
fpsr := ring0.GetFPSR()
fpsimd.Fpcr = uint32(fpcr)
fpsimd.Fpsr = uint32(fpsr)
- ring0.SaveVRegs(c.floatingPointState.BytePointer())
+ ring0.SaveVRegs(c.FloatingPointState().BytePointer()) // escapes: no
}
ring0.Halt()
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s
index 308f2a951..9690e3772 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.s
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.s
@@ -29,9 +29,12 @@
// Only limited use of the context is done in the assembly stub below, most is
// done in the Go handlers.
#define SIGINFO_SIGNO 0x0
+#define SIGINFO_CODE 0x8
#define CONTEXT_PC 0x1B8
#define CONTEXT_R0 0xB8
+#define SYS_MMAP 222
+
// getTLS returns the value of TPIDR_EL0 register.
TEXT ·getTLS(SB),NOSPLIT,$0-8
MRS TPIDR_EL0, R1
@@ -98,6 +101,37 @@ TEXT ·addrOfSighandler(SB), $0-8
MOVD R0, ret+0(FP)
RET
+// The arguments are the following:
+//
+// R0 - The signal number.
+// R1 - Pointer to siginfo_t structure.
+// R2 - Pointer to ucontext structure.
+//
+TEXT ·sigsysHandler(SB),NOSPLIT,$0
+	// si_code should be SYS_SECCOMP (== 1).
+ MOVD SIGINFO_CODE(R1), R7
+ CMPW $1, R7
+ BNE fallback
+
+ CMPW $SYS_MMAP, R8
+ BNE fallback
+
+ MOVD R2, 8(RSP)
+ BL ·seccompMmapHandler(SB) // Call the handler.
+
+ RET
+
+fallback:
+ // Jump to the previous signal handler.
+	MOVD	·savedSigsysHandler(SB), R7
+ B (R7)
+
+// func addrOfSigsysHandler() uintptr
+TEXT ·addrOfSigsysHandler(SB), $0-8
+ MOVD $·sigsysHandler(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation.
TEXT ·dieTrampoline(SB),NOSPLIT,$0
// R0: Fake the old PC as caller
diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go
index 8fd8287b3..7a3c97c5a 100644
--- a/pkg/sentry/platform/kvm/bluepill_fault.go
+++ b/pkg/sentry/platform/kvm/bluepill_fault.go
@@ -55,11 +55,7 @@ func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virt
}
// Adjust the block to match our size.
- physicalStart = alignedPhysical & faultBlockMask
- if physicalStart < pr.physical {
- // Bound the starting point to the start of the region.
- physicalStart = pr.physical
- }
+ physicalStart = pr.physical + (alignedPhysical-pr.physical)&faultBlockMask
virtualStart = pr.virtual + (physicalStart - pr.physical)
physicalEnd := physicalStart + faultBlockSize
if physicalEnd > end {
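
The rewritten computation aligns fault blocks relative to the start of the physical region instead of absolute physical address zero, which removes the clamping special case and makes blocks tile each region uniformly. A runnable toy comparison (the block size here is arbitrary, chosen only for illustration):

```go
package main

import "fmt"

const (
	faultBlockSize = uintptr(0x1000)
	faultBlockMask = ^uintptr(faultBlockSize - 1)
)

func main() {
	regionPhysical := uintptr(0x1800)  // region start (not block-aligned)
	alignedPhysical := uintptr(0x2100) // faulting physical address

	// Old: absolute alignment, then clamp to the region start.
	oldStart := alignedPhysical & faultBlockMask
	if oldStart < regionPhysical {
		oldStart = regionPhysical
	}

	// New: alignment relative to the region start.
	newStart := regionPhysical + (alignedPhysical-regionPhysical)&faultBlockMask

	fmt.Printf("old=%#x new=%#x\n", oldStart, newStart) // old=0x2000 new=0x1800
}
```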
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 0f0c1e73b..e38ca05c0 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -193,36 +193,8 @@ func bluepillHandler(context unsafe.Pointer) {
return
}
- // Increment the fault count.
- atomic.AddUint32(&c.faults, 1)
-
- // For MMIO, the physical address is the first data item.
- physical = uintptr(c.runData.data[0])
- virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE)
- if !ok {
- c.die(bluepillArchContext(context), "invalid physical address")
- return
- }
-
- // We now need to fill in the data appropriately. KVM
- // expects us to provide the result of the given MMIO
- // operation in the runData struct. This is safe
- // because, if a fault occurs here, the same fault
- // would have occurred in guest mode. The kernel should
- // not create invalid page table mappings.
- data := (*[8]byte)(unsafe.Pointer(&c.runData.data[1]))
- length := (uintptr)((uint32)(c.runData.data[2]))
- write := (uint8)(((c.runData.data[2] >> 32) & 0xff)) != 0
- for i := uintptr(0); i < length; i++ {
- b := bytePtr(uintptr(virtual) + i)
- if write {
- // Write to the given address.
- *b = data[i]
- } else {
- // Read from the given address.
- data[i] = *b
- }
- }
+ c.die(bluepillArchContext(context), "exit_mmio")
+ return
case _KVM_EXIT_IRQ_WINDOW_OPEN:
bluepillStopGuest(c)
case _KVM_EXIT_SHUTDOWN:
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index aac0fdffe..ad6863646 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -77,7 +77,11 @@ var (
// OpenDevice opens the KVM device at /dev/kvm and returns the File.
func OpenDevice() (*os.File, error) {
- f, err := os.OpenFile("/dev/kvm", unix.O_RDWR, 0)
+ dev, ok := os.LookupEnv("GVISOR_KVM_DEV")
+ if !ok {
+ dev = "/dev/kvm"
+ }
+ f, err := os.OpenFile(dev, unix.O_RDWR, 0)
if err != nil {
return nil, fmt.Errorf("error opening /dev/kvm: %v", err)
}
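
OpenDevice now honors a GVISOR_KVM_DEV environment variable, falling back to /dev/kvm when it is unset; note that on failure the error message still names /dev/kvm even when the path was overridden. A hypothetical use, pointing a test at an alternate device node (the path below is made up):

```go
package main

import (
	"fmt"
	"os"

	"gvisor.dev/gvisor/pkg/sentry/platform/kvm"
)

func main() {
	// Hypothetical path: useful where the KVM node is exposed somewhere
	// other than /dev/kvm (e.g. inside a test sandbox).
	os.Setenv("GVISOR_KVM_DEV", "/tmp/sandbox/kvm")
	f, err := kvm.OpenDevice()
	if err != nil {
		fmt.Println("open failed:", err)
		return
	}
	defer f.Close()
}
```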
diff --git a/pkg/sentry/platform/kvm/kvm_safecopy_test.go b/pkg/sentry/platform/kvm/kvm_safecopy_test.go
new file mode 100644
index 000000000..9a87c9e6f
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_safecopy_test.go
@@ -0,0 +1,104 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// FIXME(gvisor.dev/issue/6629): These tests don't pass on ARM64.
+//
+//go:build amd64
+// +build amd64
+
+package kvm
+
+import (
+ "fmt"
+ "os"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/safecopy"
+)
+
+func testSafecopy(t *testing.T, mapSize uintptr, fileSize uintptr, testFunc func(t *testing.T, c *vCPU, addr uintptr)) {
+ memfd, err := memutil.CreateMemFD(fmt.Sprintf("kvm_test_%d", os.Getpid()), 0)
+ if err != nil {
+ t.Errorf("error creating memfd: %v", err)
+ }
+
+ memfile := os.NewFile(uintptr(memfd), "kvm_test")
+ memfile.Truncate(int64(fileSize))
+ kvmTest(t, nil, func(c *vCPU) bool {
+ const n = 10
+ mappings := make([]uintptr, n)
+ defer func() {
+ for i := 0; i < n && mappings[i] != 0; i++ {
+ unix.RawSyscall(
+ unix.SYS_MUNMAP,
+ mappings[i], mapSize, 0)
+ }
+ }()
+ for i := 0; i < n; i++ {
+ addr, _, errno := unix.RawSyscall6(
+ unix.SYS_MMAP,
+ 0,
+ mapSize,
+ unix.PROT_READ|unix.PROT_WRITE,
+ unix.MAP_SHARED|unix.MAP_FILE,
+ uintptr(memfile.Fd()),
+ 0)
+ if errno != 0 {
+ t.Errorf("error mapping file: %v", errno)
+ }
+ mappings[i] = addr
+ testFunc(t, c, addr)
+ }
+ return false
+ })
+}
+
+func TestSafecopySigbus(t *testing.T) {
+ mapSize := uintptr(faultBlockSize)
+ fileSize := mapSize - hostarch.PageSize
+ buf := make([]byte, hostarch.PageSize)
+ testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) {
+ want := safecopy.BusError{addr + fileSize}
+ bluepill(c)
+ _, err := safecopy.CopyIn(buf, unsafe.Pointer(addr+fileSize))
+ if err != want {
+ t.Errorf("expected error: got %v, want %v", err, want)
+ }
+ })
+}
+
+func TestSafecopy(t *testing.T) {
+ mapSize := uintptr(faultBlockSize)
+ fileSize := mapSize
+ testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) {
+ want := uint32(0x12345678)
+ bluepill(c)
+ _, err := safecopy.SwapUint32(unsafe.Pointer(addr+fileSize-8), want)
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ bluepill(c)
+ val, err := safecopy.LoadUint32(unsafe.Pointer(addr + fileSize - 8))
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ if val != want {
+ t.Errorf("incorrect value: got %x, want %x", val, want)
+ }
+ })
+}
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index e7092a756..f1f7e4ea4 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -17,16 +17,20 @@ package kvm
import (
"fmt"
"runtime"
+ gosync "sync"
"sync/atomic"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/seccomp"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sighandling"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -35,6 +39,9 @@ type machine struct {
// fd is the vm fd.
fd int
+ // machinePoolIndex is the index in the machinePool array.
+ machinePoolIndex uint32
+
// nextSlot is the next slot for setMemoryRegion.
//
// This must be accessed atomically. If nextSlot is ^uint32(0), then
@@ -192,6 +199,10 @@ func (m *machine) newVCPU() *vCPU {
return c // Done.
}
+// readOnlyGuestRegions contains regions that have to be mapped read-only into
+// the guest physical address space. Right now, it is used on arm64 only.
+var readOnlyGuestRegions []region
+
// newMachine returns a new VM context.
func newMachine(vm int) (*machine, error) {
// Create the machine.
@@ -227,6 +238,10 @@ func newMachine(vm int) (*machine, error) {
m.upperSharedPageTables.MarkReadOnlyShared()
m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress)
+ // Install seccomp rules to trap runtime mmap system calls. They will
+ // be handled by seccompMmapHandler.
+ seccompMmapRules(m)
+
// Apply the physical mappings. Note that these mappings may point to
// guest physical addresses that are not actually available. These
// physical pages are mapped on demand, see kernel_unsafe.go.
@@ -241,32 +256,11 @@ func newMachine(vm int) (*machine, error) {
return true // Keep iterating.
})
- var physicalRegionsReadOnly []physicalRegion
- var physicalRegionsAvailable []physicalRegion
-
- physicalRegionsReadOnly = rdonlyRegionsForSetMem()
- physicalRegionsAvailable = availableRegionsForSetMem()
-
- // Map all read-only regions.
- for _, r := range physicalRegionsReadOnly {
- m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY)
- }
-
// Ensure that the currently mapped virtual regions are actually
// available in the VM. Note that this doesn't guarantee no future
// faults, however it should guarantee that everything is available to
// ensure successful vCPU entry.
- applyVirtualRegions(func(vr virtualRegion) {
- if excludeVirtualRegion(vr) {
- return // skip region.
- }
-
- for _, r := range physicalRegionsReadOnly {
- if vr.virtual == r.virtual {
- return
- }
- }
-
+ mapRegion := func(vr region, flags uint32) {
for virtual := vr.virtual; virtual < vr.virtual+vr.length; {
physical, length, ok := translateToPhysical(virtual)
if !ok {
@@ -280,9 +274,32 @@ func newMachine(vm int) (*machine, error) {
}
// Ensure the physical range is mapped.
- m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE)
+ m.mapPhysical(physical, length, physicalRegions, flags)
virtual += length
}
+ }
+
+ for _, vr := range readOnlyGuestRegions {
+ mapRegion(vr, _KVM_MEM_READONLY)
+ }
+
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ return // skip region.
+ }
+ for _, r := range readOnlyGuestRegions {
+ if vr.virtual == r.virtual {
+ return
+ }
+ }
+ // Take into account that the stack can grow down.
+ if vr.filename == "[stack]" {
+ vr.virtual -= 1 << 20
+ vr.length += 1 << 20
+ }
+
+ mapRegion(vr.region, 0)
+
})
// Initialize architecture state.
@@ -352,6 +369,10 @@ func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalReg
func (m *machine) Destroy() {
runtime.SetFinalizer(m, nil)
+ machinePoolMu.Lock()
+ machinePool[m.machinePoolIndex].Store(nil)
+ machinePoolMu.Unlock()
+
// Destroy vCPUs.
for _, c := range m.vCPUsByID {
if c == nil {
@@ -519,15 +540,21 @@ func (c *vCPU) lock() {
//
//go:nosplit
func (c *vCPU) unlock() {
- if atomic.CompareAndSwapUint32(&c.state, vCPUUser|vCPUGuest, vCPUGuest) {
+ origState := atomicbitops.CompareAndSwapUint32(&c.state, vCPUUser|vCPUGuest, vCPUGuest)
+ if origState == vCPUUser|vCPUGuest {
// Happy path: no exits are forced, and we can continue
// executing on our merry way with a single atomic access.
return
}
// Clear the lock.
- origState := atomic.LoadUint32(&c.state)
- atomicbitops.AndUint32(&c.state, ^vCPUUser)
+ for {
+ state := atomicbitops.CompareAndSwapUint32(&c.state, origState, origState&^vCPUUser)
+ if state == origState {
+ break
+ }
+ origState = state
+ }
switch origState {
case vCPUUser:
// Normal state.
@@ -677,3 +704,72 @@ func (c *vCPU) setSystemTimeLegacy() error {
}
}
}
+
+const machinePoolSize = 16
+
+// machinePool is enumerated from the seccompMmapHandler signal handler.
+var (
+ machinePool [machinePoolSize]machineAtomicPtr
+ machinePoolLen uint32
+ machinePoolMu sync.Mutex
+ seccompMmapRulesOnce gosync.Once
+)
+
+func sigsysHandler()
+func addrOfSigsysHandler() uintptr
+
+// seccompMmapRules adds seccomp rules to trap mmap system calls that will be
+// handled in seccompMmapHandler.
+func seccompMmapRules(m *machine) {
+ seccompMmapRulesOnce.Do(func() {
+ // Install the handler.
+ if err := sighandling.ReplaceSignalHandler(unix.SIGSYS, addrOfSigsysHandler(), &savedSigsysHandler); err != nil {
+ panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
+ }
+ rules := []seccomp.RuleSet{}
+ rules = append(rules, []seccomp.RuleSet{
+			// Trap mmap system calls and handle them in seccompMmapHandler.
+ {
+ Rules: seccomp.SyscallRules{
+ unix.SYS_MMAP: {
+ {
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ /* MAP_DENYWRITE is ignored and used only for filtering. */
+ seccomp.MaskedEqual(unix.MAP_DENYWRITE, 0),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_TRAP,
+ },
+ }...)
+ instrs, err := seccomp.BuildProgram(rules, linux.SECCOMP_RET_ALLOW, linux.SECCOMP_RET_ALLOW)
+ if err != nil {
+ panic(fmt.Sprintf("failed to build rules: %v", err))
+ }
+ // Perform the actual installation.
+ if err := seccomp.SetFilter(instrs); err != nil {
+ panic(fmt.Sprintf("failed to set filter: %v", err))
+ }
+ })
+
+ machinePoolMu.Lock()
+ n := atomic.LoadUint32(&machinePoolLen)
+ i := uint32(0)
+ for ; i < n; i++ {
+ if machinePool[i].Load() == nil {
+ break
+ }
+ }
+ if i == n {
+ if i == machinePoolSize {
+ machinePoolMu.Unlock()
+ panic("machinePool is full")
+ }
+ atomic.AddUint32(&machinePoolLen, 1)
+ }
+ machinePool[i].Store(m)
+ m.machinePoolIndex = i
+ machinePoolMu.Unlock()
+}
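
The seccomp setup above and the handler in machine_unsafe.go form a trap-and-retry protocol: the filter traps only mmap calls whose flags lack MAP_DENYWRITE (via MaskedEqual), and the SIGSYS handler re-issues the trapped call with MAP_DENYWRITE added. Because the kernel ignores that flag, the retried call behaves identically but now passes the filter, and the handler can then map the fresh host region into every machine registered in machinePool. A condensed sketch of the retry half:

```go
package kvmexample

import "golang.org/x/sys/unix"

// retryMmap re-executes a trapped mmap with the marker flag added. The kernel
// ignores MAP_DENYWRITE, so the call is semantically unchanged, but the
// seccomp filter (which traps only calls *without* the flag) lets it through.
func retryMmap(addr, length, prot, flags, fd, offset uintptr) (uintptr, unix.Errno) {
	r, _, e := unix.RawSyscall6(unix.SYS_MMAP,
		addr, length, prot,
		flags|unix.MAP_DENYWRITE,
		fd, offset)
	return r, e
}
```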
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index a96634381..5bc023899 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
)
@@ -72,10 +71,6 @@ type vCPUArchState struct {
//
// This starts above fixedKernelPCID.
PCIDs *pagetables.PCIDs
-
- // floatingPointState is the floating point state buffer used in guest
- // to host transitions. See usage in bluepill_amd64.go.
- floatingPointState fpu.State
}
const (
@@ -152,12 +147,6 @@ func (c *vCPU) initArchState() error {
return fmt.Errorf("error setting user registers: %v", errno)
}
- // Allocate some floating point state save area for the local vCPU.
- // This will be saved prior to leaving the guest, and we restore from
- // this always. We cannot use the pointer in the context alone because
- // we don't know how large the area there is in reality.
- c.floatingPointState = fpu.NewState()
-
// Set the time offset to the host native time.
return c.setSystemTime()
}
@@ -309,22 +298,6 @@ func loadByte(ptr *byte) byte {
return *ptr
}
-// prefaultFloatingPointState touches each page of the floating point state to
-// be sure that its physical pages are mapped.
-//
-// Otherwise the kernel can trigger KVM_EXIT_MMIO and an instruction that
-// triggered a fault will be emulated by the kvm kernel code, but it can't
-// emulate instructions like xsave and xrstor.
-//
-//go:nosplit
-func prefaultFloatingPointState(data *fpu.State) {
- size := len(*data)
- for i := 0; i < size; i += hostarch.PageSize {
- loadByte(&(*data)[i])
- }
- loadByte(&(*data)[size-1])
-}
-
// SwitchToUser unpacks architectural-details.
func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) {
// Check for canonical addresses.
@@ -355,11 +328,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo)
// allocations occur.
entersyscall()
bluepill(c)
- // The root table physical page has to be mapped to not fault in iret
- // or sysret after switching into a user address space. sysret and
- // iret are in the upper half that is global and already mapped.
- switchOpts.PageTables.PrefaultRootTable()
- prefaultFloatingPointState(switchOpts.FloatingPointState)
vector = c.CPU.SwitchToUser(switchOpts)
exitsyscall()
@@ -522,3 +490,7 @@ func (m *machine) getNewVCPU() *vCPU {
}
return nil
}
+
+func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion {
+ return physicalRegions
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
index de798bb2c..fbacea9ad 100644
--- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -161,3 +161,15 @@ func (c *vCPU) getSystemRegisters(sregs *systemRegs) unix.Errno {
}
return 0
}
+
+//go:nosplit
+func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) {
+ ctx := bluepillArchContext(context)
+
+	// MAP_DENYWRITE is deprecated and ignored by the kernel. We use it only for seccomp filters.
+ addr, _, e := unix.RawSyscall6(uintptr(ctx.Rax), uintptr(ctx.Rdi), uintptr(ctx.Rsi),
+ uintptr(ctx.Rdx), uintptr(ctx.R10)|unix.MAP_DENYWRITE, uintptr(ctx.R8), uintptr(ctx.R9))
+ ctx.Rax = uint64(addr)
+
+ return addr, uintptr(ctx.Rsi), e
+}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 7937a8481..31998a600 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
)
@@ -40,10 +39,6 @@ type vCPUArchState struct {
//
// This starts above fixedKernelPCID.
PCIDs *pagetables.PCIDs
-
- // floatingPointState is the floating point state buffer used in guest
- // to host transitions. See usage in bluepill_arm64.go.
- floatingPointState fpu.State
}
const (
@@ -110,18 +105,128 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
return phyRegions
}
+// archPhysicalRegions fills readOnlyGuestRegions and allocates separate
+// physical regions for them.
+func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion {
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ return // skip region.
+ }
+ if !vr.accessType.Write {
+ readOnlyGuestRegions = append(readOnlyGuestRegions, vr.region)
+ }
+ })
+
+ rdRegions := readOnlyGuestRegions[:]
+
+ // Add an unreachable region.
+ rdRegions = append(rdRegions, region{
+ virtual: 0xffffffffffffffff,
+ length: 0,
+ })
+
+ var regions []physicalRegion
+ addValidRegion := func(r *physicalRegion, virtual, length uintptr) {
+ if length == 0 {
+ return
+ }
+ regions = append(regions, physicalRegion{
+ region: region{
+ virtual: virtual,
+ length: length,
+ },
+ physical: r.physical + (virtual - r.virtual),
+ })
+ }
+ i := 0
+ for _, pr := range physicalRegions {
+ start := pr.virtual
+ end := pr.virtual + pr.length
+ for start < end {
+ rdRegion := rdRegions[i]
+ rdStart := rdRegion.virtual
+ rdEnd := rdRegion.virtual + rdRegion.length
+ if rdEnd <= start {
+ i++
+ continue
+ }
+ if rdStart > start {
+ newEnd := rdStart
+ if end < rdStart {
+ newEnd = end
+ }
+ addValidRegion(&pr, start, newEnd-start)
+ start = rdStart
+ continue
+ }
+ if rdEnd < end {
+ addValidRegion(&pr, start, rdEnd-start)
+ start = rdEnd
+ continue
+ }
+ addValidRegion(&pr, start, end-start)
+ start = end
+ }
+ }
+
+ return regions
+}
+
// Get all available physicalRegions.
-func availableRegionsForSetMem() (phyRegions []physicalRegion) {
- var excludeRegions []region
+func availableRegionsForSetMem() []physicalRegion {
+ var excludedRegions []region
applyVirtualRegions(func(vr virtualRegion) {
if !vr.accessType.Write {
- excludeRegions = append(excludeRegions, vr.region)
+ excludedRegions = append(excludedRegions, vr.region)
}
})
- phyRegions = computePhysicalRegions(excludeRegions)
+ // Add an unreachable region.
+ excludedRegions = append(excludedRegions, region{
+ virtual: 0xffffffffffffffff,
+ length: 0,
+ })
- return phyRegions
+ var regions []physicalRegion
+ addValidRegion := func(r *physicalRegion, virtual, length uintptr) {
+ if length == 0 {
+ return
+ }
+ regions = append(regions, physicalRegion{
+ region: region{
+ virtual: virtual,
+ length: length,
+ },
+ physical: r.physical + (virtual - r.virtual),
+ })
+ }
+ i := 0
+ for _, pr := range physicalRegions {
+ start := pr.virtual
+ end := pr.virtual + pr.length
+ for start < end {
+ er := excludedRegions[i]
+ excludeEnd := er.virtual + er.length
+ excludeStart := er.virtual
+ if excludeEnd < start {
+ i++
+ continue
+ }
+ if excludeStart < start {
+ start = excludeEnd
+ i++
+ continue
+ }
+ rend := excludeStart
+ if rend > end {
+ rend = end
+ }
+ addValidRegion(&pr, start, rend-start)
+ start = excludeEnd
+ }
+ }
+
+ return regions
}
// nonCanonical generates a canonical address return.
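
Both new walks above implement the same range-subtraction pattern: iterate over the physical regions, consume a sorted list of special ranges, and append a sentinel at the top of the address space so the loop never has to test for exhausting the list. archPhysicalRegions keeps the read-only spans as separate regions (to be mapped with _KVM_MEM_READONLY), while availableRegionsForSetMem drops excluded spans entirely. A self-contained sketch of the dropping variant:

```go
package main

import "fmt"

type span struct{ start, end uintptr }

// subtract removes a sorted list of excluded ranges from [start, end).
func subtract(start, end uintptr, excluded []span) []span {
	// Sentinel: an unreachable range so the loop never runs off the list.
	excluded = append(excluded, span{^uintptr(0), ^uintptr(0)})
	var out []span
	i := 0
	for start < end {
		ex := excluded[i]
		switch {
		case ex.end <= start: // exclusion entirely behind the cursor
			i++
		case ex.start > start: // keep up to the next exclusion (or end)
			stop := ex.start
			if end < stop {
				stop = end
			}
			out = append(out, span{start, stop})
			start = stop
		default: // cursor inside an exclusion: skip past it
			start = ex.end
		}
	}
	return out
}

func main() {
	fmt.Println(subtract(0x1000, 0x9000, []span{{0x2000, 0x3000}, {0x8000, 0xa000}}))
	// [{4096 8192} {12288 32768}], i.e. [0x1000,0x2000) and [0x3000,0x8000)
}
```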
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 1a4a9ce7d..e73d5c544 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -28,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
)
@@ -159,8 +158,6 @@ func (c *vCPU) initArchState() error {
c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs)
}
- c.floatingPointState = fpu.NewState()
-
return c.setSystemTime()
}
@@ -333,3 +330,15 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo)
}
}
+
+//go:nosplit
+func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) {
+ ctx := bluepillArchContext(context)
+
+	// MAP_DENYWRITE is deprecated and ignored by the kernel. We use it only for seccomp filters.
+ addr, _, e := unix.RawSyscall6(uintptr(ctx.Regs[8]), uintptr(ctx.Regs[0]), uintptr(ctx.Regs[1]),
+ uintptr(ctx.Regs[2]), uintptr(ctx.Regs[3])|unix.MAP_DENYWRITE, uintptr(ctx.Regs[4]), uintptr(ctx.Regs[5]))
+ ctx.Regs[0] = uint64(addr)
+
+ return addr, uintptr(ctx.Regs[1]), e
+}
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index cc3a1253b..cf3a4e7c9 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -171,3 +171,46 @@ func (c *vCPU) setSignalMask() error {
return nil
}
+
+// seccompMmapHandler is a signal handler for runtime mmap system calls
+// that are trapped by seccomp.
+//
+// It executes the mmap syscall with the specified arguments and maps the new
+// region into all registered guests.
+//
+//go:nosplit
+func seccompMmapHandler(context unsafe.Pointer) {
+ addr, length, errno := seccompMmapSyscall(context)
+ if errno != 0 {
+ return
+ }
+
+ for i := uint32(0); i < atomic.LoadUint32(&machinePoolLen); i++ {
+ m := machinePool[i].Load()
+ if m == nil {
+ continue
+ }
+
+ // Map the new region to the guest.
+ vr := region{
+ virtual: addr,
+ length: length,
+ }
+ for virtual := vr.virtual; virtual < vr.virtual+vr.length; {
+ physical, length, ok := translateToPhysical(virtual)
+ if !ok {
+ // This must be an invalid region that was
+ // knocked out by creation of the physical map.
+ return
+ }
+ if virtual+length > vr.virtual+vr.length {
+ // Cap the length to the end of the area.
+ length = vr.virtual + vr.length - virtual
+ }
+
+ // Ensure the physical range is mapped.
+ m.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE)
+ virtual += length
+ }
+ }
+}
diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go
index d812e6c26..9864d1258 100644
--- a/pkg/sentry/platform/kvm/physical_map.go
+++ b/pkg/sentry/platform/kvm/physical_map.go
@@ -168,6 +168,9 @@ func computePhysicalRegions(excludedRegions []region) (physicalRegions []physica
}
addValidRegion(lastExcludedEnd, ring0.MaximumUserAddress-lastExcludedEnd)
+ // Do arch-specific actions on physical regions.
+ physicalRegions = archPhysicalRegions(physicalRegions)
+
// Dump our all physical regions.
for _, r := range physicalRegions {
log.Infof("physicalRegion: virtual [%x,%x) => physical [%x,%x)",
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
index 6d0ba8252..346a10043 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
@@ -30,8 +30,8 @@ import (
func TLSWorks() bool
// SetTestTarget sets the rip appropriately.
-func SetTestTarget(regs *arch.Registers, fn func()) {
- regs.Pc = uint64(reflect.ValueOf(fn).Pointer())
+func SetTestTarget(regs *arch.Registers, fn uintptr) {
+ regs.Pc = uint64(fn)
}
// SetTouchTarget sets rax appropriately.
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
index 7348c29a5..42876245a 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
@@ -28,6 +28,11 @@ TEXT ·Getpid(SB),NOSPLIT,$0
SVC
RET
+TEXT ·AddrOfGetpid(SB),NOSPLIT,$0-8
+ MOVD $·Getpid(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·Touch(SB),NOSPLIT,$0
start:
MOVD 0(R8), R1
@@ -35,21 +40,41 @@ start:
SVC
B start
+TEXT ·AddrOfTouch(SB),NOSPLIT,$0-8
+ MOVD $·Touch(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·HaltLoop(SB),NOSPLIT,$0
start:
HLT
B start
+TEXT ·AddrOfHaltLoop(SB),NOSPLIT,$0-8
+ MOVD $·HaltLoop(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// This function simulates a loop of syscall.
TEXT ·SyscallLoop(SB),NOSPLIT,$0
start:
SVC
B start
+TEXT ·AddrOfSyscallLoop(SB),NOSPLIT,$0-8
+ MOVD $·SyscallLoop(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·SpinLoop(SB),NOSPLIT,$0
start:
B start
+TEXT ·AddrOfSpinLoop(SB),NOSPLIT,$0-8
+ MOVD $·SpinLoop(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·TLSWorks(SB),NOSPLIT,$0-8
NO_LOCAL_POINTERS
MOVD $0x6789, R5
@@ -125,6 +150,11 @@ TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0
SVC
RET // never reached
+TEXT ·AddrOfTwiddleRegsSyscall(SB),NOSPLIT,$0-8
+ MOVD $·TwiddleRegsSyscall(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0
TWIDDLE_REGS()
MSR R10, TPIDR_EL0
@@ -132,3 +162,8 @@ TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0
// Branch to Register branches unconditionally to an address in <Rn>.
JMP (R6) // <=> br x6, must fault
RET // never reached
+
+TEXT ·AddrOfTwiddleRegsFault(SB),NOSPLIT,$0-8
+ MOVD $·TwiddleRegsFault(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
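
The AddrOf* trampolines replace the earlier pattern of deriving code addresses from Go function values via reflect (visible in the testutil_arm64.go hunk above). Returning $·fn(SB) from assembly yields the raw entry point directly, which is likely the motivation here: with newer toolchains a func value's pointer may refer to a wrapper rather than the instruction sequence the test intends to execute in guest mode. The Go side reduces to declarations plus a trivial setter:

```go
package testutilexample

import "gvisor.dev/gvisor/pkg/sentry/arch"

// AddrOfGetpid is implemented in testutil_arm64.s; it returns the raw code
// address of Getpid.
func AddrOfGetpid() uintptr

// SetTestTarget takes a code address instead of a Go func value, as in the
// diff above. Typical use: SetTestTarget(regs, AddrOfGetpid()).
func SetTestTarget(regs *arch.Registers, fn uintptr) {
	regs.Pc = uint64(fn)
}
```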
diff --git a/pkg/sentry/seccheck/BUILD b/pkg/sentry/seccheck/BUILD
new file mode 100644
index 000000000..35feb969f
--- /dev/null
+++ b/pkg/sentry/seccheck/BUILD
@@ -0,0 +1,58 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_fieldenum:defs.bzl", "go_fieldenum")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_fieldenum(
+ name = "seccheck_fieldenum",
+ srcs = [
+ "clone.go",
+ "execve.go",
+ "exit.go",
+ "task.go",
+ ],
+ out = "seccheck_fieldenum.go",
+ package = "seccheck",
+)
+
+go_template_instance(
+ name = "seqatomic_checkerslice",
+ out = "seqatomic_checkerslice_unsafe.go",
+ package = "seccheck",
+ suffix = "CheckerSlice",
+ template = "//pkg/sync/seqatomic:generic_seqatomic",
+ types = {
+ "Value": "[]Checker",
+ },
+)
+
+go_library(
+ name = "seccheck",
+ srcs = [
+ "clone.go",
+ "execve.go",
+ "exit.go",
+ "seccheck.go",
+ "seccheck_fieldenum.go",
+ "seqatomic_checkerslice_unsafe.go",
+ "task.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/gohacks",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "seccheck_test",
+ size = "small",
+ srcs = ["seccheck_test.go"],
+ library = ":seccheck",
+ deps = ["//pkg/context"],
+)
diff --git a/pkg/sentry/seccheck/clone.go b/pkg/sentry/seccheck/clone.go
new file mode 100644
index 000000000..7546fa021
--- /dev/null
+++ b/pkg/sentry/seccheck/clone.go
@@ -0,0 +1,53 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// CloneInfo contains information used by the Clone checkpoint.
+//
+// +fieldenum Clone
+type CloneInfo struct {
+ // Invoker identifies the invoking thread.
+ Invoker TaskInfo
+
+ // Credentials are the invoking thread's credentials.
+ Credentials *auth.Credentials
+
+ // Args contains the arguments to kernel.Task.Clone().
+ Args linux.CloneArgs
+
+ // Created identifies the created thread.
+ Created TaskInfo
+}
+
+// CloneReq returns fields required by the Clone checkpoint.
+func (s *state) CloneReq() CloneFieldSet {
+ return s.cloneReq.Load()
+}
+
+// Clone is called at the Clone checkpoint.
+func (s *state) Clone(ctx context.Context, mask CloneFieldSet, info *CloneInfo) error {
+ for _, c := range s.getCheckers() {
+ if err := c.Clone(ctx, mask, *info); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/seccheck/execve.go b/pkg/sentry/seccheck/execve.go
new file mode 100644
index 000000000..f36e0730e
--- /dev/null
+++ b/pkg/sentry/seccheck/execve.go
@@ -0,0 +1,65 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// ExecveInfo contains information used by the Execve checkpoint.
+//
+// +fieldenum Execve
+type ExecveInfo struct {
+ // Invoker identifies the invoking thread.
+ Invoker TaskInfo
+
+ // Credentials are the invoking thread's credentials.
+ Credentials *auth.Credentials
+
+ // BinaryPath is a path to the executable binary file being switched to in
+ // the mount namespace in which it was opened.
+ BinaryPath string
+
+ // Argv is the new process image's argument vector.
+ Argv []string
+
+ // Env is the new process image's environment variables.
+ Env []string
+
+ // BinaryMode is the executable binary file's mode.
+ BinaryMode uint16
+
+ // BinarySHA256 is the SHA-256 hash of the executable binary file.
+ //
+ // Note that this requires reading the entire file into memory, which is
+ // likely to be extremely slow.
+ BinarySHA256 [32]byte
+}
+
+// ExecveReq returns fields required by the Execve checkpoint.
+func (s *state) ExecveReq() ExecveFieldSet {
+ return s.execveReq.Load()
+}
+
+// Execve is called at the Execve checkpoint.
+func (s *state) Execve(ctx context.Context, mask ExecveFieldSet, info *ExecveInfo) error {
+ for _, c := range s.getCheckers() {
+ if err := c.Execve(ctx, mask, *info); err != nil {
+ return err
+ }
+ }
+ return nil
+}
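
Because BinarySHA256 forces the sentry to read the whole binary, a checker that does not need the hash should leave it out of its request. A hedged sketch, assuming the +fieldenum generator emits one bool per ExecveInfo field (as CloneFields does for CloneInfo), with myChecker standing in for a real implementation:

    seccheck.Global.AppendChecker(myChecker, &seccheck.CheckerReq{
        Points: []seccheck.Point{seccheck.PointExecve},
        Execve: seccheck.ExecveFields{
            BinaryPath: true,
            Argv:       true,
            // BinarySHA256 deliberately omitted: requesting it reads the
            // entire executable into memory.
        },
    })
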
diff --git a/pkg/sentry/seccheck/exit.go b/pkg/sentry/seccheck/exit.go
new file mode 100644
index 000000000..69cb6911c
--- /dev/null
+++ b/pkg/sentry/seccheck/exit.go
@@ -0,0 +1,57 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// ExitNotifyParentInfo contains information used by the ExitNotifyParent
+// checkpoint.
+//
+// +fieldenum ExitNotifyParent
+type ExitNotifyParentInfo struct {
+ // Exiter identifies the exiting thread. Note that by the checkpoint's
+ // definition, Exiter.ThreadID == Exiter.ThreadGroupID and
+ // Exiter.ThreadStartTime == Exiter.ThreadGroupStartTime, so requesting
+ // ThreadGroup* fields is redundant.
+ Exiter TaskInfo
+
+ // ExitStatus is the exiting thread group's exit status, as reported
+ // by wait*().
+ ExitStatus linux.WaitStatus
+}
+
+// ExitNotifyParentReq returns fields required by the ExitNotifyParent
+// checkpoint.
+func (s *state) ExitNotifyParentReq() ExitNotifyParentFieldSet {
+ return s.exitNotifyParentReq.Load()
+}
+
+// ExitNotifyParent is called at the ExitNotifyParent checkpoint.
+//
+// The ExitNotifyParent checkpoint occurs when a zombied thread group leader,
+// not waiting for exit acknowledgement from a non-parent ptracer, becomes the
+// last non-dead thread in its thread group and notifies its parent of its
+// exit.
+func (s *state) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info *ExitNotifyParentInfo) error {
+ for _, c := range s.getCheckers() {
+ if err := c.ExitNotifyParent(ctx, mask, *info); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/seccheck/seccheck.go b/pkg/sentry/seccheck/seccheck.go
new file mode 100644
index 000000000..e13274096
--- /dev/null
+++ b/pkg/sentry/seccheck/seccheck.go
@@ -0,0 +1,158 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package seccheck defines a structure for dynamically-configured security
+// checks in the sentry.
+package seccheck
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// A Point represents a checkpoint, a point at which a security check occurs.
+type Point uint
+
+// PointX represents the checkpoint X.
+const (
+ PointClone Point = iota
+ PointExecve
+ PointExitNotifyParent
+ // Add new Points above this line.
+ pointLength
+
+ numPointBitmaskUint32s = (int(pointLength)-1)/32 + 1
+)
+
+// A Checker performs security checks at checkpoints.
+//
+// Each Checker method X is called at checkpoint X; if the method may return a
+// non-nil error and does so, it causes the checked operation to fail
+// immediately (without calling subsequent Checkers) and return the error. The
+// info argument contains information relevant to the check. The mask argument
+// indicates what fields in info are valid; the mask should usually be a
+// superset of fields requested by the Checker's corresponding CheckerReq, but
+// may be missing requested fields in some cases (e.g. if the Checker is
+// registered concurrently with invocations of checkpoints).
+type Checker interface {
+ Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error
+ Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error
+ ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error
+}
+
+// CheckerDefaults may be embedded by implementations of Checker to obtain
+// no-op implementations of Checker methods that may be explicitly overridden.
+type CheckerDefaults struct{}
+
+// Clone implements Checker.Clone.
+func (CheckerDefaults) Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ return nil
+}
+
+// Execve implements Checker.Execve.
+func (CheckerDefaults) Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error {
+ return nil
+}
+
+// ExitNotifyParent implements Checker.ExitNotifyParent.
+func (CheckerDefaults) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error {
+ return nil
+}
+
+// CheckerReq indicates what checkpoints a corresponding Checker runs at, and
+// what information it requires at those checkpoints.
+type CheckerReq struct {
+ // Points are the set of checkpoints for which the corresponding Checker
+ // must be called. Note that methods not specified in Points may still be
+ // called; implementations of Checker may embed CheckerDefaults to obtain
+ // no-op implementations of Checker methods.
+ Points []Point
+
+ // All of the following fields indicate what fields in the corresponding
+ // XInfo struct will be requested at the corresponding checkpoint.
+ Clone CloneFields
+ Execve ExecveFields
+ ExitNotifyParent ExitNotifyParentFields
+}
+
+// Global is the method receiver of all seccheck functions.
+var Global state
+
+// state is the type of Global, and is separated out for testing.
+type state struct {
+ // registrationMu serializes all changes to the set of registered Checkers
+ // for all checkpoints.
+ registrationMu sync.Mutex
+
+ // enabledPoints is a bitmask of checkpoints for which at least one Checker
+ // is registered.
+ //
+ // enabledPoints is accessed using atomic memory operations. Mutation of
+ // enabledPoints is serialized by registrationMu.
+ enabledPoints [numPointBitmaskUint32s]uint32
+
+ // registrationSeq supports store-free atomic reads of registeredCheckers.
+ registrationSeq sync.SeqCount
+
+ // checkers is the set of all registered Checkers in order of execution.
+ //
+ // checkers is accessed using instantiations of SeqAtomic functions.
+ // Mutation of checkers is serialized by registrationMu.
+ checkers []Checker
+
+ // All of the following xReq variables indicate what fields in the
+ // corresponding XInfo struct have been requested by any registered
+ // checker, are accessed using atomic memory operations, and are mutated
+ // with registrationMu locked.
+ cloneReq CloneFieldSet
+ execveReq ExecveFieldSet
+ exitNotifyParentReq ExitNotifyParentFieldSet
+}
+
+// AppendChecker registers the given Checker to execute at checkpoints. The
+// Checker will execute after all previously-registered Checkers, and only if
+// those Checkers return a nil error.
+func (s *state) AppendChecker(c Checker, req *CheckerReq) {
+ s.registrationMu.Lock()
+ defer s.registrationMu.Unlock()
+
+ s.cloneReq.AddFieldsLoadable(req.Clone)
+ s.execveReq.AddFieldsLoadable(req.Execve)
+ s.exitNotifyParentReq.AddFieldsLoadable(req.ExitNotifyParent)
+
+ s.appendCheckerLocked(c)
+ for _, p := range req.Points {
+ word, bit := p/32, p%32
+ atomic.StoreUint32(&s.enabledPoints[word], s.enabledPoints[word]|(uint32(1)<<bit))
+ }
+}
+
+// Enabled returns true if any Checker is registered for the given checkpoint.
+func (s *state) Enabled(p Point) bool {
+ word, bit := p/32, p%32
+ return atomic.LoadUint32(&s.enabledPoints[word])&(uint32(1)<<bit) != 0
+}
+
+func (s *state) getCheckers() []Checker {
+ return SeqAtomicLoadCheckerSlice(&s.registrationSeq, &s.checkers)
+}
+
+// Preconditions: s.registrationMu must be locked.
+func (s *state) appendCheckerLocked(c Checker) {
+ s.registrationSeq.BeginWrite()
+ s.checkers = append(s.checkers, c)
+ s.registrationSeq.EndWrite()
+}
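
Putting the pieces above together: a minimal checker embeds CheckerDefaults for the checkpoints it ignores and registers itself once. Registration is serialized by registrationMu, while checkpoints read the checker slice lock-free through the SeqCount. In this sketch the logChecker type is illustrative only, and log is assumed to be gVisor's pkg/log:

    type logChecker struct {
        seccheck.CheckerDefaults // no-op for all non-overridden checkpoints
    }

    func (logChecker) ExitNotifyParent(ctx context.Context, mask seccheck.ExitNotifyParentFieldSet, info seccheck.ExitNotifyParentInfo) error {
        log.Infof("thread group %d exited: %v", info.Exiter.ThreadGroupID, info.ExitStatus)
        return nil
    }

    func init() {
        seccheck.Global.AppendChecker(logChecker{}, &seccheck.CheckerReq{
            Points: []seccheck.Point{seccheck.PointExitNotifyParent},
        })
    }
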
diff --git a/pkg/sentry/seccheck/seccheck_test.go b/pkg/sentry/seccheck/seccheck_test.go
new file mode 100644
index 000000000..687810d18
--- /dev/null
+++ b/pkg/sentry/seccheck/seccheck_test.go
@@ -0,0 +1,157 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ "errors"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+type testChecker struct {
+ CheckerDefaults
+
+ onClone func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error
+}
+
+// Clone implements Checker.Clone.
+func (c *testChecker) Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ if c.onClone == nil {
+ return nil
+ }
+ return c.onClone(ctx, mask, info)
+}
+
+func TestNoChecker(t *testing.T) {
+ var s state
+ if s.Enabled(PointClone) {
+ t.Errorf("Enabled(PointClone): got true, wanted false")
+ }
+}
+
+func TestCheckerNotRegisteredForPoint(t *testing.T) {
+ var s state
+ s.AppendChecker(&testChecker{}, &CheckerReq{})
+ if s.Enabled(PointClone) {
+ t.Errorf("Enabled(PointClone): got true, wanted false")
+ }
+}
+
+func TestCheckerRegistered(t *testing.T) {
+ var s state
+ checkerCalled := false
+ s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ checkerCalled = true
+ return nil
+ }}, &CheckerReq{
+ Points: []Point{PointClone},
+ Clone: CloneFields{
+ Credentials: true,
+ },
+ })
+
+ if !s.Enabled(PointClone) {
+ t.Errorf("Enabled(PointClone): got false, wanted true")
+ }
+ if !s.CloneReq().Contains(CloneFieldCredentials) {
+ t.Errorf("CloneReq().Contains(CloneFieldCredentials): got false, wanted true")
+ }
+ if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != nil {
+ t.Errorf("Clone(): got %v, wanted nil", err)
+ }
+ if !checkerCalled {
+ t.Errorf("Clone() did not call Checker.Clone()")
+ }
+}
+
+func TestMultipleCheckersRegistered(t *testing.T) {
+ var s state
+ checkersCalled := [2]bool{}
+ s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ checkersCalled[0] = true
+ return nil
+ }}, &CheckerReq{
+ Points: []Point{PointClone},
+ Clone: CloneFields{
+ Args: true,
+ },
+ })
+ s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ checkersCalled[1] = true
+ return nil
+ }}, &CheckerReq{
+ Points: []Point{PointClone},
+ Clone: CloneFields{
+ Created: TaskFields{
+ ThreadID: true,
+ },
+ },
+ })
+
+ if !s.Enabled(PointClone) {
+ t.Errorf("Enabled(PointClone): got false, wanted true")
+ }
+ // CloneReq() should return the union of requested fields from all calls to
+ // AppendChecker.
+ req := s.CloneReq()
+ if !req.Contains(CloneFieldArgs) {
+ t.Errorf("req.Contains(CloneFieldArgs): got false, wanted true")
+ }
+ if !req.Created.Contains(TaskFieldThreadID) {
+ t.Errorf("req.Created.Contains(TaskFieldThreadID): got false, wanted true")
+ }
+ if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != nil {
+ t.Errorf("Clone(): got %v, wanted nil", err)
+ }
+ for i := range checkersCalled {
+ if !checkersCalled[i] {
+ t.Errorf("Clone() did not call Checker.Clone() index %d", i)
+ }
+ }
+}
+
+func TestCheckpointReturnsFirstCheckerError(t *testing.T) {
+ errFirstChecker := errors.New("first Checker error")
+ errSecondChecker := errors.New("second Checker error")
+
+ var s state
+ checkersCalled := [2]bool{}
+ s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ checkersCalled[0] = true
+ return errFirstChecker
+ }}, &CheckerReq{
+ Points: []Point{PointClone},
+ })
+ s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error {
+ checkersCalled[1] = true
+ return errSecondChecker
+ }}, &CheckerReq{
+ Points: []Point{PointClone},
+ })
+
+ if !s.Enabled(PointClone) {
+ t.Errorf("Enabled(PointClone): got false, wanted true")
+ }
+ if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != errFirstChecker {
+ t.Errorf("Clone(): got %v, wanted %v", err, errFirstChecker)
+ }
+ if !checkersCalled[0] {
+ t.Errorf("Clone() did not call first Checker")
+ }
+ if checkersCalled[1] {
+ t.Errorf("Clone() called second Checker")
+ }
+}
diff --git a/pkg/sentry/seccheck/task.go b/pkg/sentry/seccheck/task.go
new file mode 100644
index 000000000..1dee33203
--- /dev/null
+++ b/pkg/sentry/seccheck/task.go
@@ -0,0 +1,39 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+)
+
+// TaskInfo contains information unambiguously identifying a single thread
+// and/or its containing process.
+//
+// +fieldenum Task
+type TaskInfo struct {
+ // ThreadID is the thread's ID in the root PID namespace.
+ ThreadID int32
+
+ // ThreadStartTime is the thread's CLOCK_REALTIME start time.
+ ThreadStartTime ktime.Time
+
+ // ThreadGroupID is the thread's group leader's ID in the root PID
+ // namespace.
+ ThreadGroupID int32
+
+ // ThreadGroupStartTime is the thread's group leader's CLOCK_REALTIME start
+ // time.
+ ThreadGroupStartTime ktime.Time
+}
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 7ee89a735..00f925166 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -4,7 +4,10 @@ package(licenses = ["notice"])
go_library(
name = "socket",
- srcs = ["socket.go"],
+ srcs = [
+ "socket.go",
+ "socket_state.go",
+ ],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 00a5e729a..6077b2150 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -29,10 +29,9 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "time"
)
-const maxInt = int(^uint(0) >> 1)
-
// SCMCredentials represents a SCM_CREDENTIALS socket control message.
type SCMCredentials interface {
transport.CredentialsControlMessage
@@ -78,7 +77,7 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
}
// Files implements SCMRights.Files.
-func (fs *RightsFiles) Files(ctx context.Context, max int) (RightsFiles, bool) {
+func (fs *RightsFiles) Files(_ context.Context, max int) (RightsFiles, bool) {
n := max
var trunc bool
if l := len(*fs); n > l {
@@ -124,7 +123,7 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32
break
}
- fds = append(fds, int32(fd))
+ fds = append(fds, fd)
}
return fds, trunc
}
@@ -300,8 +299,8 @@ func alignSlice(buf []byte, align uint) []byte {
}
// PackTimestamp packs a SO_TIMESTAMP socket control message.
-func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte {
- timestampP := linux.NsecToTimeval(timestamp)
+func PackTimestamp(t *kernel.Task, timestamp time.Time, buf []byte) []byte {
+ timestampP := linux.NsecToTimeval(timestamp.UnixNano())
return putCmsgStruct(
buf,
linux.SOL_SOCKET,
@@ -355,6 +354,17 @@ func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketIn
)
}
+// PackIPv6PacketInfo packs an IPV6_PKTINFO socket control message.
+func PackIPv6PacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPv6PacketInfo, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IPV6,
+ linux.IPV6_PKTINFO,
+ t.Arch().Width(),
+ packetInfo,
+ )
+}
+
// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message.
func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte {
var level uint32
@@ -412,6 +422,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf)
}
+ if cmsgs.IP.HasIPv6PacketInfo {
+ buf = PackIPv6PacketInfo(t, &cmsgs.IP.IPv6PacketInfo, buf)
+ }
+
if cmsgs.IP.OriginalDstAddress != nil {
buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf)
}
@@ -453,6 +467,10 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo)
}
+ if cmsgs.IP.HasIPv6PacketInfo {
+ space += cmsgSpace(t, linux.SizeOfControlMessageIPv6PacketInfo)
+ }
+
if cmsgs.IP.OriginalDstAddress != nil {
space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes())
}
@@ -526,7 +544,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
var ts linux.Timeval
ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval])
- cmsgs.IP.Timestamp = ts.ToNsecCapped()
+ cmsgs.IP.Timestamp = ts.ToTime()
cmsgs.IP.HasTimestamp = true
i += bits.AlignUp(length, width)
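
With the timestamp now carried as a time.Time, conversion to the Linux ABI happens only at the cmsg boundary. A sketch of the round trip, with t standing in for a *kernel.Task:

    recv := time.Now()
    buf := control.PackTimestamp(t, recv, nil)
    // buf holds a SOL_SOCKET/SO_TIMESTAMP cmsg carrying
    // linux.NsecToTimeval(recv.UnixNano()), i.e. microsecond granularity;
    // Parse() recovers it on the way back in via ts.ToTime().
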
diff --git a/pkg/sentry/socket/control/control_test.go b/pkg/sentry/socket/control/control_test.go
index 7e28a0cef..1b04e1bbc 100644
--- a/pkg/sentry/socket/control/control_test.go
+++ b/pkg/sentry/socket/control/control_test.go
@@ -50,7 +50,7 @@ func TestParse(t *testing.T) {
want := socket.ControlMessages{
IP: socket.IPControlMessages{
HasTimestamp: true,
- Timestamp: ts.ToNsecCapped(),
+ Timestamp: ts.ToTime(),
},
}
if diff := cmp.Diff(want, cmsg); diff != "" {
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index 3950caa0f..4ea89f9d0 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -38,7 +38,6 @@ go_library(
"//pkg/sentry/socket/control",
"//pkg/sentry/vfs",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/stack",
"//pkg/usermem",
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 38cb2c99c..6e2318f75 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -35,7 +35,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -112,7 +111,7 @@ func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
return readv(s.fd, safemem.IovecsFromBlockSeq(dsts))
}))
- return int64(n), err
+ return n, err
}
// Write implements fs.FileOperations.Write.
@@ -135,7 +134,7 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
}
return writev(s.fd, safemem.IovecsFromBlockSeq(srcs))
}))
- return int64(n), err
+ return n, err
}
// Socket implements socket.Provider.Socket.
@@ -181,7 +180,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto
}
// Pair implements socket.Provider.Pair.
-func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+func (p *socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
// Not supported by AF_INET/AF_INET6.
return nil, nil, nil
}
@@ -208,7 +207,7 @@ type socketOpsCommon struct {
// Release implements fs.FileOperations.Release.
func (s *socketOpsCommon) Release(context.Context) {
fdnotifier.RemoveFD(int32(s.fd))
- unix.Close(s.fd)
+ _ = unix.Close(s.fd)
}
// Readiness implements waiter.Waitable.Readiness.
@@ -219,13 +218,13 @@ func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
// EventRegister implements waiter.Waitable.EventRegister.
func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
s.queue.EventRegister(e, mask)
- fdnotifier.UpdateFD(int32(s.fd))
+ _ = fdnotifier.UpdateFD(int32(s.fd))
}
// EventUnregister implements waiter.Waitable.EventUnregister.
func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
s.queue.EventUnregister(e)
- fdnotifier.UpdateFD(int32(s.fd))
+ _ = fdnotifier.UpdateFD(int32(s.fd))
}
// Connect implements socket.Socket.Connect.
@@ -288,7 +287,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int,
fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC)
if blocking {
var ch chan struct{}
- for syscallErr == syserror.ErrWouldBlock {
+ for syscallErr == linuxerr.ErrWouldBlock {
if ch != nil {
if syscallErr = t.Block(ch); syscallErr != nil {
break
@@ -317,7 +316,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int,
if kernel.VFS2Enabled {
f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&unix.SOCK_NONBLOCK))
if err != nil {
- unix.Close(fd)
+ _ = unix.Close(fd)
return 0, nil, 0, err
}
defer f.DecRef(t)
@@ -329,7 +328,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int,
} else {
f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&unix.SOCK_NONBLOCK != 0)
if err != nil {
- unix.Close(fd)
+ _ = unix.Close(fd)
return 0, nil, 0, err
}
defer f.DecRef(t)
@@ -344,7 +343,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int,
}
// Bind implements socket.Socket.Bind.
-func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) > sizeofSockaddr {
sockaddr = sockaddr[:sizeofSockaddr]
}
@@ -357,12 +356,12 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
// Listen implements socket.Socket.Listen.
-func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error {
return syserr.FromError(unix.Listen(s.fd, backlog))
}
// Shutdown implements socket.Socket.Shutdown.
-func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error {
switch how {
case unix.SHUT_RD, unix.SHUT_WR, unix.SHUT_RDWR:
return syserr.FromError(unix.Shutdown(s.fd, how))
@@ -372,7 +371,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, _ hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
if outLen < 0 {
return nil, syserr.ErrInvalidArgument
}
@@ -402,7 +401,7 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
case linux.TCP_NODELAY:
optlen = sizeofInt32
case linux.TCP_INFO:
- optlen = int(linux.SizeOfTCPInfo)
+ optlen = linux.SizeOfTCPInfo
}
}
@@ -535,7 +534,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
n, err := copyToDst()
// recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT.
if flags&(unix.MSG_DONTWAIT|unix.MSG_ERRQUEUE) == 0 {
- for err == syserror.ErrWouldBlock {
+ for err == linuxerr.ErrWouldBlock {
// We only expect blocking to come from the actual syscall, in which
// case it can't have returned any data.
if n != 0 {
@@ -580,7 +579,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
controlMessages.IP.HasTimestamp = true
ts := linux.Timeval{}
ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval])
- controlMessages.IP.Timestamp = ts.ToNsecCapped()
+ controlMessages.IP.Timestamp = ts.ToTime()
}
case linux.SOL_IP:
@@ -707,7 +706,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
var ch chan struct{}
n, err := src.CopyInTo(t, sendmsgFromBlocks)
if flags&unix.MSG_DONTWAIT == 0 {
- for err == syserror.ErrWouldBlock {
+ for err == linuxerr.ErrWouldBlock {
// We only expect blocking to come from the actual syscall, in which
// case it can't have returned any data.
if n != 0 {
@@ -716,7 +715,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
if ch != nil {
if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
break
}
@@ -735,7 +734,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
func translateIOSyscallError(err error) error {
if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
- return syserror.ErrWouldBlock
+ return linuxerr.ErrWouldBlock
}
return err
}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index ea56f39c1..b9c15daab 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -647,7 +647,7 @@ func (jt *JumpTarget) id() targetID {
}
// Action implements stack.Target.Action.
-func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
+func (jt *JumpTarget) Action(*stack.PacketBuffer, stack.Hook, *stack.Route, stack.AddressableEndpoint) (stack.RuleVerdict, int) {
return stack.RuleJump, jt.RuleNum
}
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index ed85404da..9710a15ee 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -36,7 +36,6 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 5c3ae26f8..ed5fa9c38 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -39,7 +39,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -530,7 +529,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
- if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ if n, err := doRead(); err != linuxerr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
var mflags int
if n < int64(r.MsgSize) {
mflags |= linux.MSG_TRUNC
@@ -548,7 +547,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
defer s.EventUnregister(&e)
for {
- if n, err := doRead(); err != syserror.ErrWouldBlock {
+ if n, err := doRead(); err != linuxerr.ErrWouldBlock {
var mflags int
if n < int64(r.MsgSize) {
mflags |= linux.MSG_TRUNC
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index e828982eb..075f61cda 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"device.go",
"netstack.go",
+ "netstack_state.go",
"netstack_vfs2.go",
"provider.go",
"provider_vfs2.go",
@@ -42,13 +43,13 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/header",
"//pkg/tcpip/link/tun",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 9b844b0c0..030c6c8e4 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -56,12 +56,11 @@ import (
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
- "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -275,6 +274,7 @@ var Metrics = tcpip.Stats{
ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."),
FailedPortReservations: mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of times TCP failed to reserve a port."),
SegmentsAckedWithDSACK: mustCreateMetric("/netstack/tcp/segments_acked_with_dsack", "Number of segments for which DSACK was received."),
+ SpuriousRecovery: mustCreateMetric("/netstack/tcp/spurious_recovery", "Number of times the connection entered loss recovery spuriously."),
},
UDP: tcpip.UDPStats{
PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."),
@@ -379,9 +379,9 @@ type socketOpsCommon struct {
// timestampValid indicates whether timestamp for SIOCGSTAMP has been
// set. It is protected by readMu.
timestampValid bool
- // timestampNS holds the timestamp to use with SIOCTSTAMP. It is only
+ // timestamp holds the timestamp to use with SIOCGSTAMP. It is only
// valid when timestampValid is true. It is protected by readMu.
- timestampNS int64
+ timestamp time.Time `state:".(int64)"`
// TODO(b/153685824): Move this to SocketOptions.
// sockOptInq corresponds to TCP_INQ.
@@ -411,13 +411,25 @@ var sockAddrInetSize = (*linux.SockAddrInet)(nil).SizeBytes()
var sockAddrInet6Size = (*linux.SockAddrInet6)(nil).SizeBytes()
var sockAddrLinkSize = (*linux.SockAddrLink)(nil).SizeBytes()
-// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
-// netstack representation taking any addresses into account.
-func bytesToIPAddress(addr []byte) tcpip.Address {
- if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) {
- return ""
+// minSockAddrLen returns the minimum length in bytes of a socket address for
+// the socket's family.
+func (s *socketOpsCommon) minSockAddrLen() int {
+ const addressFamilySize = 2
+
+ switch s.family {
+ case linux.AF_UNIX:
+ return addressFamilySize
+ case linux.AF_INET:
+ return sockAddrInetSize
+ case linux.AF_INET6:
+ return sockAddrInet6Size
+ case linux.AF_PACKET:
+ return sockAddrLinkSize
+ case linux.AF_UNSPEC:
+ return addressFamilySize
+ default:
+ panic(fmt.Sprintf("unrecognized socket family %d", s.family))
}
- return tcpip.Address(addr)
}
func (s *socketOpsCommon) isPacketBased() bool {
@@ -448,7 +460,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) {
t := kernel.TaskFromContext(ctx)
start := t.Kernel().MonotonicClock().Now()
deadline := start.Add(v.Timeout)
- t.BlockWithDeadline(ch, true, deadline)
+ _ = t.BlockWithDeadline(ch, true, deadline)
}
}
@@ -459,7 +471,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
if err == syserr.ErrWouldBlock {
- return int64(n), syserror.ErrWouldBlock
+ return int64(n), linuxerr.ErrWouldBlock
}
if err != nil {
return 0, err.ToError()
@@ -468,7 +480,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
// WriteTo implements fs.FileOperations.WriteTo.
-func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
+func (s *SocketOperations) WriteTo(_ context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
s.readMu.Lock()
defer s.readMu.Unlock()
@@ -492,14 +504,14 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
r := src.Reader(ctx)
n, err := s.Endpoint.Write(r, tcpip.WriteOptions{})
if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
if err != nil {
return 0, syserr.TranslateNetstackError(err).ToError()
}
if n < src.NumBytes() {
- return n, syserror.ErrWouldBlock
+ return n, linuxerr.ErrWouldBlock
}
return n, nil
@@ -523,7 +535,7 @@ func (l *limitedPayloader) Len() int {
}
// ReadFrom implements fs.FileOperations.ReadFrom.
-func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+func (s *SocketOperations) ReadFrom(_ context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
f := limitedPayloader{
inner: io.LimitedReader{
R: r,
@@ -546,16 +558,21 @@ func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.Endpoint.Readiness(mask)
}
-func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
+// checkFamily returns true iff the specified address family may be used with
+// the socket.
+//
+// If exact is true, then the specified address family must be an exact match
+// with the socket's family.
+func (s *socketOpsCommon) checkFamily(family uint16, exact bool) bool {
if family == uint16(s.family) {
- return nil
+ return true
}
if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
if !s.Endpoint.SocketOptions().GetV6Only() {
- return nil
+ return true
}
}
- return syserr.ErrInvalidArgument
+ return false
}
// mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the
@@ -588,8 +605,8 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
return syserr.TranslateNetstackError(err)
}
- if err := s.checkFamily(family, false /* exact */); err != nil {
- return err
+ if !s.checkFamily(family, false /* exact */) {
+ return syserr.ErrInvalidArgument
}
addr = s.mapFamily(addr, family)
@@ -629,7 +646,7 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) < 2 {
return syserr.ErrInvalidArgument
}
@@ -647,23 +664,24 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
a.UnmarshalBytes(sockaddr[:sockAddrLinkSize])
- if a.Protocol != uint16(s.protocol) {
- return syserr.ErrInvalidArgument
- }
-
addr = tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ Port: socket.Ntohs(a.Protocol),
}
} else {
+ if s.minSockAddrLen() > len(sockaddr) {
+ return syserr.ErrInvalidArgument
+ }
+
var err *syserr.Error
addr, family, err = socket.AddressAndFamily(sockaddr)
if err != nil {
return err
}
- if err = s.checkFamily(family, true /* exact */); err != nil {
- return err
+ if !s.checkFamily(family, true /* exact */) {
+ return syserr.ErrAddressFamilyNotSupported
}
addr = s.mapFamily(addr, family)
@@ -688,7 +706,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
// Listen implements the linux syscall listen(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error {
return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog))
}
@@ -779,7 +797,7 @@ func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) {
// Shutdown implements the linux syscall shutdown(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error {
f, err := ConvertShutdown(how)
if err != nil {
return err
@@ -860,7 +878,7 @@ func boolToInt32(v bool) int32 {
}
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, _ linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -1345,6 +1363,14 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
return &v, nil
+ case linux.IPV6_RECVPKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetIPv6ReceivePacketInfo()))
+ return &v, nil
+
case linux.IP6T_ORIGINAL_DST:
if outLen < sockAddrInet6Size {
return nil, syserr.ErrInvalidArgument
@@ -1368,11 +1394,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, true)
+ info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, true)
if err != nil {
return nil, err
}
@@ -1388,11 +1414,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- entries, err := netfilter.GetEntries6(t, stack.(*Stack).Stack, outPtr, outLen)
+ entries, err := netfilter.GetEntries6(t, stk.(*Stack).Stack, outPtr, outLen)
if err != nil {
return nil, err
}
@@ -1408,8 +1434,8 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber)
@@ -1425,7 +1451,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, _ int) (marshal.Marshallable, *syserr.Error) {
if _, ok := ep.(tcpip.Endpoint); !ok {
log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
return nil, syserr.ErrUnknownProtocolOption
@@ -1565,11 +1591,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, false)
+ info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, false)
if err != nil {
return nil, err
}
@@ -1585,11 +1611,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- entries, err := netfilter.GetEntries4(t, stack.(*Stack).Stack, outPtr, outLen)
+ entries, err := netfilter.GetEntries4(t, stk.(*Stack).Stack, outPtr, outLen)
if err != nil {
return nil, err
}
@@ -1605,8 +1631,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
ret, err := netfilter.TargetRevision(t, outPtr, header.IPv4ProtocolNumber)
@@ -2046,7 +2072,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial {
return syserr.ErrInvalidEndpointState
- } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial {
+ } else if isUDPSocket(skType, skProto) && transport.DatagramEndpointState(ep.State()) != transport.DatagramEndpointStateInitial {
return syserr.ErrInvalidEndpointState
}
@@ -2101,6 +2127,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
return nil
+ case linux.IPV6_RECVPKTINFO:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(hostarch.ByteOrder.Uint32(optVal))
+
+ ep.SocketOptions().SetIPv6ReceivePacketInfo(v != 0)
+ return nil
+
case linux.IPV6_TCLASS:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2143,12 +2178,12 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return syserr.ErrNoDevice
}
// Stack must be a netstack stack.
- return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, true)
+ return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, true)
case linux.IP6T_SO_SET_ADD_COUNTERS:
log.Infof("IP6T_SO_SET_ADD_COUNTERS is not supported")
@@ -2386,12 +2421,12 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return syserr.ErrNoDevice
}
// Stack must be a netstack stack.
- return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, false)
+ return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, false)
case linux.IPT_SO_SET_ADD_COUNTERS:
log.Infof("IPT_SO_SET_ADD_COUNTERS is not supported")
@@ -2490,7 +2525,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
linux.IPV6_RECVPATHMTU,
- linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
linux.IPV6_RTHDR,
linux.IPV6_RTHDRDSTOPTS,
@@ -2559,7 +2593,7 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetSockName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -2571,7 +2605,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetPeerName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -2716,6 +2750,8 @@ func (s *socketOpsCommon) controlMessages(cm tcpip.ControlMessages) socket.Contr
TClass: readCM.TClass,
HasIPPacketInfo: readCM.HasIPPacketInfo,
PacketInfo: readCM.PacketInfo,
+ HasIPv6PacketInfo: readCM.HasIPv6PacketInfo,
+ IPv6PacketInfo: readCM.IPv6PacketInfo,
OriginalDstAddress: readCM.OriginalDstAddress,
SockErr: readCM.SockErr,
},
@@ -2730,7 +2766,7 @@ func (s *socketOpsCommon) updateTimestamp(cm tcpip.ControlMessages) {
// Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
if !s.sockOptTimestamp {
s.timestampValid = true
- s.timestampNS = cm.Timestamp
+ s.timestamp = cm.Timestamp
}
}
@@ -2789,7 +2825,7 @@ func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int,
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, _ uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
if flags&linux.MSG_ERRQUEUE != 0 {
return s.recvErr(t, dst)
}
@@ -2873,8 +2909,8 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
if err != nil {
return 0, err
}
- if err := s.checkFamily(family, false /* exact */); err != nil {
- return 0, err
+ if !s.checkFamily(family, false /* exact */) {
+ return 0, syserr.ErrInvalidArgument
}
addrBuf = s.mapFamily(addrBuf, family)
@@ -2951,10 +2987,10 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
s.readMu.Lock()
defer s.readMu.Unlock()
if !s.timestampValid {
- return 0, syserror.ENOENT
+ return 0, linuxerr.ENOENT
}
- tv := linux.NsecToTimeval(s.timestampNS)
+ tv := linux.NsecToTimeval(s.timestamp.UnixNano())
_, err := tv.CopyOut(t, args[2].Pointer())
return 0, err
@@ -3061,7 +3097,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
}
// interfaceIoctl implements interface requests.
-func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error {
+func interfaceIoctl(ctx context.Context, _ usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error {
var (
iface inet.Interface
index int32
@@ -3069,8 +3105,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
)
// Find the relevant device.
- stack := inet.StackFromContext(ctx)
- if stack == nil {
+ stk := inet.StackFromContext(ctx)
+ if stk == nil {
return syserr.ErrNoDevice
}
@@ -3080,7 +3116,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
// Gets the name of the interface given the interface index
// stored in ifr_ifindex.
index = int32(hostarch.ByteOrder.Uint32(ifr.Data[:4]))
- if iface, ok := stack.Interfaces()[index]; ok {
+ if iface, ok := stk.Interfaces()[index]; ok {
ifr.SetName(iface.Name)
return nil
}
@@ -3088,7 +3124,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
// Find the relevant device.
- for index, iface = range stack.Interfaces() {
+ for index, iface = range stk.Interfaces() {
if iface.Name == ifr.Name() {
found = true
break
@@ -3121,7 +3157,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
case linux.SIOCGIFFLAGS:
- f, err := interfaceStatusFlags(stack, iface.Name)
+ f, err := interfaceStatusFlags(stk, iface.Name)
if err != nil {
return err
}
@@ -3131,7 +3167,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
case linux.SIOCGIFADDR:
// Copy the IPv4 address out.
- for _, addr := range stack.InterfaceAddrs()[index] {
+ for _, addr := range stk.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
if addr.Family != linux.AF_INET {
continue
@@ -3167,7 +3203,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
case linux.SIOCGIFNETMASK:
// Gets the network mask of a device.
- for _, addr := range stack.InterfaceAddrs()[index] {
+ for _, addr := range stk.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
if addr.Family != linux.AF_INET {
continue
@@ -3199,24 +3235,24 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
// ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl.
-func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error {
+func ifconfIoctl(ctx context.Context, t *kernel.Task, _ usermem.IO, ifc *linux.IFConf) error {
// If Ptr is NULL, return the necessary buffer size via Len.
// Otherwise, write up to Len bytes starting at Ptr containing ifreq
// structs.
- stack := inet.StackFromContext(ctx)
- if stack == nil {
+ stk := inet.StackFromContext(ctx)
+ if stk == nil {
return syserr.ErrNoDevice.ToError()
}
if ifc.Ptr == 0 {
- ifc.Len = int32(len(stack.Interfaces())) * int32(linux.SizeOfIFReq)
+ ifc.Len = int32(len(stk.Interfaces())) * int32(linux.SizeOfIFReq)
return nil
}
max := ifc.Len
ifc.Len = 0
- for key, ifaceAddrs := range stack.InterfaceAddrs() {
- iface := stack.Interfaces()[key]
+ for key, ifaceAddrs := range stk.InterfaceAddrs() {
+ iface := stk.Interfaces()[key]
for _, ifaceAddr := range ifaceAddrs {
// Don't write past the end of the buffer.
if ifc.Len+int32(linux.SizeOfIFReq) > max {
@@ -3332,10 +3368,10 @@ func (s *socketOpsCommon) State() uint32 {
}
case isUDPSocket(s.skType, s.protocol):
// UDP socket.
- switch udp.EndpointState(s.Endpoint.State()) {
- case udp.StateInitial, udp.StateBound, udp.StateClosed:
+ switch transport.DatagramEndpointState(s.Endpoint.State()) {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateBound, transport.DatagramEndpointStateClosed:
return linux.TCP_CLOSE
- case udp.StateConnected:
+ case transport.DatagramEndpointStateConnected:
return linux.TCP_ESTABLISHED
default:
return 0
diff --git a/pkg/sentry/socket/netstack/netstack_state.go b/pkg/sentry/socket/netstack/netstack_state.go
new file mode 100644
index 000000000..591e00d42
--- /dev/null
+++ b/pkg/sentry/socket/netstack/netstack_state.go
@@ -0,0 +1,31 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "time"
+)
+
+func (s *socketOpsCommon) saveTimestamp() int64 {
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ return s.timestamp.UnixNano()
+}
+
+func (s *socketOpsCommon) loadTimestamp(nsec int64) {
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.timestamp = time.Unix(0, nsec)
+}
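
The state:".(int64)" tag on the timestamp field (added in the netstack.go hunk above) is what routes gVisor's checkpoint/restore machinery through these two hooks: saveTimestamp supplies the serialized int64 and loadTimestamp reconstructs the time.Time, so the non-serializable time.Time never reaches the state file directly. Assuming that convention holds generally, a hypothetical field named deadline would pair with saveDeadline/loadDeadline the same way:

    type conn struct {
        deadline time.Time `state:".(int64)"` // serialized via the hooks below
    }

    func (c *conn) saveDeadline() int64     { return c.deadline.UnixNano() }
    func (c *conn) loadDeadline(nsec int64) { c.deadline = time.Unix(0, nsec) }
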
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index edc160b1b..3cdf29b80 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -27,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -113,7 +112,7 @@ func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.
}
n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
if err == syserr.ErrWouldBlock {
- return int64(n), syserror.ErrWouldBlock
+ return int64(n), linuxerr.ErrWouldBlock
}
if err != nil {
return 0, err.ToError()
@@ -132,14 +131,14 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
r := src.Reader(ctx)
n, err := s.Endpoint.Write(r, tcpip.WriteOptions{})
if _, ok := err.(*tcpip.ErrWouldBlock); ok {
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
if err != nil {
return 0, syserr.TranslateNetstackError(err).ToError()
}
if n < src.NumBytes() {
- return n, syserror.ErrWouldBlock
+ return n, linuxerr.ErrWouldBlock
}
return n, nil
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 208ab9909..ea199f223 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -155,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
// Attach address to interface.
nicID := tcpip.NICID(idx)
- if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ if err := s.Stack.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil {
return syserr.TranslateNetstackError(err).ToError()
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 658e90bb9..d4b80a39d 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -21,6 +21,7 @@ import (
"bytes"
"fmt"
"sync/atomic"
+ "time"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -51,8 +52,19 @@ type ControlMessages struct {
func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo {
var p linux.ControlMessageIPPacketInfo
p.NIC = int32(packetInfo.NIC)
- copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
- copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+ copy(p.LocalAddr[:], packetInfo.LocalAddr)
+ copy(p.DestinationAddr[:], packetInfo.DestinationAddr)
+ return p
+}
+
+// ipv6PacketInfoToLinux converts IPv6PacketInfo from tcpip format to Linux
+// format.
+func ipv6PacketInfoToLinux(packetInfo tcpip.IPv6PacketInfo) linux.ControlMessageIPv6PacketInfo {
+ var p linux.ControlMessageIPv6PacketInfo
+ if n := copy(p.Addr[:], packetInfo.Addr); n != len(p.Addr) {
+ panic(fmt.Sprintf("got copy(%x, %x) = %d, want = %d", p.Addr, packetInfo.Addr, n, len(p.Addr)))
+ }
+ p.NIC = uint32(packetInfo.NIC)
return p
}
@@ -114,7 +126,7 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa
if cmgs.HasOriginalDstAddress {
orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress)
}
- return IPControlMessages{
+ cm := IPControlMessages{
HasTimestamp: cmgs.HasTimestamp,
Timestamp: cmgs.Timestamp,
HasInq: cmgs.HasInq,
@@ -125,9 +137,16 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa
TClass: cmgs.TClass,
HasIPPacketInfo: cmgs.HasIPPacketInfo,
PacketInfo: packetInfoToLinux(cmgs.PacketInfo),
+ HasIPv6PacketInfo: cmgs.HasIPv6PacketInfo,
OriginalDstAddress: orgDstAddr,
SockErr: sockErrCmsgToLinux(cmgs.SockErr),
}
+
+ if cm.HasIPv6PacketInfo {
+ cm.IPv6PacketInfo = ipv6PacketInfoToLinux(cmgs.IPv6PacketInfo)
+ }
+
+ return cm
}
// IPControlMessages contains socket control messages for IP sockets.
@@ -138,9 +157,9 @@ type IPControlMessages struct {
// HasTimestamp indicates whether Timestamp is valid/set.
HasTimestamp bool
- // Timestamp is the time (in ns) that the last packet used to create
- // the read data was received.
- Timestamp int64
+ // Timestamp is the time that the last packet used to create the read data
+ // was received.
+ Timestamp time.Time `state:".(int64)"`
// HasInq indicates whether Inq is valid/set.
HasInq bool
@@ -166,6 +185,12 @@ type IPControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo linux.ControlMessageIPPacketInfo
+ // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set.
+ HasIPv6PacketInfo bool
+
+ // IPv6PacketInfo holds interface and address data on an incoming packet.
+ IPv6PacketInfo linux.ControlMessageIPv6PacketInfo
+
// OriginalDstAddress holds the original destination address and port of
// the incoming packet.
OriginalDstAddress linux.SockAddr
@@ -743,6 +768,8 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
a.UnmarshalUnsafe(addr[:sockAddrLinkSize])
+ // TODO(https://gvisor.dev/issue/6530): Do not assume all interfaces have
+ // an ethernet address.
if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
@@ -750,6 +777,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
return tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ Port: Ntohs(a.Protocol),
}, family, nil
case linux.AF_UNSPEC:
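
The new Port field carries sockaddr_ll.sll_protocol, which arrives in network byte order; Ntohs byte-swaps it into host order. A standalone sketch of that conversion, assuming the little-endian host that the package's own helper also assumes (ntohs and the sample value here are illustrative, not the package's code):

    package main

    import "fmt"

    // ntohs converts a 16-bit value from network (big-endian) to host
    // byte order. Like the package's Ntohs helper, this assumes a
    // little-endian host, where the conversion is a byte swap.
    func ntohs(v uint16) uint16 {
        return v<<8 | v>>8
    }

    func main() {
        // ETH_P_IP (0x0800) as it reads from sll_protocol on a
        // little-endian host.
        fmt.Printf("%#06x\n", ntohs(0x0008)) // 0x0800
    }
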
diff --git a/pkg/sentry/socket/socket_state.go b/pkg/sentry/socket/socket_state.go
new file mode 100644
index 000000000..32e12b238
--- /dev/null
+++ b/pkg/sentry/socket/socket_state.go
@@ -0,0 +1,27 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package socket
+
+import (
+ "time"
+)
+
+func (i *IPControlMessages) saveTimestamp() int64 {
+ return i.Timestamp.UnixNano()
+}
+
+func (i *IPControlMessages) loadTimestamp(nsec int64) {
+ i.Timestamp = time.Unix(0, nsec)
+}
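
These two methods pair with the state:".(int64)" tag added to Timestamp above: across checkpoint/restore, the wall-clock time is collapsed to nanoseconds since the epoch and then rebuilt. A self-contained illustration of the round-trip, using a hypothetical instant:

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        // Save: collapse the timestamp to an integer, as saveTimestamp does.
        orig := time.Unix(0, 1634000000123456789) // hypothetical instant
        nsec := orig.UnixNano()

        // Load: rebuild an equivalent time.Time, as loadTimestamp does.
        restored := time.Unix(0, nsec)
        fmt.Println(orig.Equal(restored)) // true
    }
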
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index 5c3cdef6a..7b546c04d 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -62,7 +62,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 33f9aeb06..b3f0cf563 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -129,9 +129,9 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv
stype: stype,
}
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
- ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
return ep
}
@@ -406,14 +406,15 @@ func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error {
// Accept accepts a new connection.
func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) {
e.Lock()
- defer e.Unlock()
if !e.Listening() {
+ e.Unlock()
return nil, syserr.ErrInvalidEndpointState
}
select {
case ne := <-e.acceptedChan:
+ e.Unlock()
if peerAddr != nil {
ne.Lock()
c := ne.connected
@@ -429,6 +430,7 @@ func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *s
return ne, nil
default:
+ e.Unlock()
// Nothing left.
return nil, syserr.ErrWouldBlock
}
@@ -517,3 +519,6 @@ func (e *connectionedEndpoint) OnSetSendBufferSize(v int64) (newSz int64) {
}
return v
}
+
+// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters.
+func (e *connectionedEndpoint) WakeupWriters() {}
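
The Accept change above replaces a deferred unlock with an explicit Unlock on every path, so e's lock is never held while the accepted endpoint ne is locked; otherwise two endpoints accepting each other's peers could each hold one lock while waiting for the other. A standalone sketch of the pattern, with a hypothetical endpoint type rather than the transport package's:

    package main

    import (
        "fmt"
        "sync"
    )

    type endpoint struct {
        mu        sync.Mutex
        listening bool
    }

    // accept drops e.mu on every path before ne.mu is taken, so no
    // goroutine ever holds two endpoint locks at once.
    func accept(e, ne *endpoint) {
        e.mu.Lock()
        if !e.listening {
            e.mu.Unlock()
            return
        }
        e.mu.Unlock()

        ne.mu.Lock()
        defer ne.mu.Unlock()
        fmt.Println("read peer state with only ne.mu held")
    }

    func main() {
        accept(&endpoint{listening: true}, &endpoint{})
    }
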
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 61338728a..61311718e 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -44,9 +44,9 @@ func NewConnectionless(ctx context.Context) Endpoint {
q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: defaultBufferSize}
q.InitRefs()
ep.receiver = &queueReceiver{readQueue: &q}
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
- ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
return ep
}
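
Both this constructor and the connection-oriented one above now run InitHandler before the buffer-size setters, presumably so SetSendBufferSize and SetReceiveBufferSize operate on fully initialized socket options instead of a handler-less instance. A hypothetical sketch of the hazard that ordering avoids (made-up types; the real SocketOptions API differs):

    package main

    import "fmt"

    type options struct {
        clamp   func(int64) int64 // installed by initHandler
        sendBuf int64
    }

    func (o *options) initHandler(clamp func(int64) int64) { o.clamp = clamp }

    func (o *options) setSendBufferSize(v int64) {
        if o.clamp != nil {
            v = o.clamp(v) // limits apply only once a handler is present
        }
        o.sendBuf = v
    }

    func main() {
        o := &options{}
        o.initHandler(func(v int64) int64 {
            if v > 64 {
                return 64
            }
            return v
        })
        o.setSendBufferSize(1024)
        fmt.Println(o.sendBuf) // 64: the limit was consulted
    }
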
@@ -227,3 +227,6 @@ func (e *connectionlessEndpoint) OnSetSendBufferSize(v int64) (newSz int64) {
}
return v
}
+
+// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters.
+func (e *connectionlessEndpoint) WakeupWriters() {}
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index e4de44498..188ad3bd9 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -59,12 +59,14 @@ func (q *queue) Close() {
// q.WriterQueue.Notify(waiter.WritableEvents)
func (q *queue) Reset(ctx context.Context) {
q.mu.Lock()
- for cur := q.dataList.Front(); cur != nil; cur = cur.Next() {
- cur.Release(ctx)
- }
+ dataList := q.dataList
q.dataList.Reset()
q.used = 0
q.mu.Unlock()
+
+ for cur := dataList.Front(); cur != nil; cur = cur.Next() {
+ cur.Release(ctx)
+ }
}
// DecRef implements RefCounter.DecRef.
@@ -133,7 +135,7 @@ func (q *queue) Enqueue(ctx context.Context, data [][]byte, c ControlMessages, f
free := q.limit - q.used
if l > free && truncate {
- if free == 0 {
+ if free <= 0 {
// Message can't fit right now.
q.mu.Unlock()
return 0, false, syserr.ErrWouldBlock
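
The Reset change above detaches the message list under q.mu and releases the entries only after unlocking, so Release, which may take other locks or drop references with side effects, never runs inside the critical section. A self-contained sketch of the detach-then-release pattern using container/list and a hypothetical queue type:

    package main

    import (
        "container/list"
        "fmt"
        "sync"
    )

    type queue struct {
        mu   sync.Mutex
        data *list.List
        used int
    }

    func (q *queue) reset() {
        q.mu.Lock()
        detached := q.data  // keep a reference to the old entries
        q.data = list.New() // the queue is observably empty from here on
        q.used = 0
        q.mu.Unlock()

        // Release entries without holding q.mu: releasing may take other
        // locks or run callbacks that must not nest under q.mu.
        for e := detached.Front(); e != nil; e = e.Next() {
            fmt.Println("release", e.Value)
        }
    }

    func main() {
        q := &queue{data: list.New()}
        q.data.PushBack("a")
        q.data.PushBack("b")
        q.reset()
    }
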
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 8ccdadae9..e9e482017 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -38,7 +38,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -494,7 +493,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
}
n, err := src.CopyInTo(t, &w)
- if err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ if err != linuxerr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
return int(n), syserr.FromError(err)
}
@@ -514,13 +513,13 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
n, err = src.CopyInTo(t, &w)
total += n
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
break
}
if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
break
}
@@ -648,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
var total int64
- if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait {
+ if n, err := doRead(); err != linuxerr.ErrWouldBlock || dontWait {
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
@@ -683,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
defer s.EventUnregister(&e)
for {
- if n, err := doRead(); err != syserror.ErrWouldBlock {
+ if n, err := doRead(); err != linuxerr.ErrWouldBlock {
var from linux.SockAddr
var fromLen uint32
if r.From != nil {
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 757ff2a40..4d3f4d556 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -610,9 +610,9 @@ func (i *SyscallInfo) printExit(t *kernel.Task, elapsed time.Duration, output []
if err == nil {
// Fill in the output after successful execution.
i.post(t, args, retval, output, LogMaximumSize)
- rval = fmt.Sprintf("%#x (%v)", retval, elapsed)
+ rval = fmt.Sprintf("%d (%#x) (%v)", retval, retval, elapsed)
} else {
- rval = fmt.Sprintf("%#x errno=%d (%s) (%v)", retval, errno, err, elapsed)
+ rval = fmt.Sprintf("%d (%#x) errno=%d (%s) (%v)", retval, retval, errno, err, elapsed)
}
switch len(output) {
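
As a worked example with a hypothetical elapsed time, a read(2) returning 16 bytes was previously logged as "0x10 (1.5µs)" and is now logged as "16 (0x10) (1.5µs)"; failing calls gain the same decimal-plus-hex return value ahead of the existing errno annotation, per the two Sprintf formats above.
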
diff --git a/pkg/sentry/syscalls/BUILD b/pkg/sentry/syscalls/BUILD
index f2c55588f..7a7c80ac6 100644
--- a/pkg/sentry/syscalls/BUILD
+++ b/pkg/sentry/syscalls/BUILD
@@ -16,7 +16,6 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/epoll",
"//pkg/sentry/kernel/time",
- "//pkg/syserror",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index b5a371d9a..394396cde 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -104,7 +104,6 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go
index 76389fbe3..f4d549a3f 100644
--- a/pkg/sentry/syscalls/linux/error.go
+++ b/pkg/sentry/syscalls/linux/error.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
var (
@@ -90,9 +89,9 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er
}
// Translate error, if possible, to consolidate errors from other packages
- // into a smaller set of errors from syserror package.
+ // into a smaller set of errors from linuxerr package.
translatedErr := errOrig
- if errno, ok := syserror.TranslateError(errOrig); ok {
+ if errno, ok := linuxerr.TranslateError(errOrig); ok {
translatedErr = errno
}
switch {
@@ -167,10 +166,7 @@ func handleIOErrorImpl(ctx context.Context, partialResult bool, errOrig, intr er
// files. Since we have a partial read/write, we consume
// ErrWouldBlock, returning the partial result.
return true, nil
- }
-
- switch errOrig.(type) {
- case syserror.SyscallRestartErrno:
+ case linuxerr.IsRestartError(translatedErr):
// Identical to the EINTR case.
return true, nil
}
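
The restart check above folds into the main switch and now runs on the translated error, so every source error that TranslateError maps onto a restart errno is caught by one case. A standalone sketch of the translate-then-classify shape (the names here are made up; linuxerr's real sentinels differ):

    package main

    import (
        "errors"
        "fmt"
    )

    var (
        errRaw       = errors.New("backend failure")         // package-private error
        errCanonical = errors.New("canonical restart error") // linuxerr-style sentinel
    )

    // translate maps assorted package errors onto a small canonical set,
    // mirroring linuxerr.TranslateError in spirit.
    func translate(err error) error {
        if errors.Is(err, errRaw) {
            return errCanonical
        }
        return err
    }

    func main() {
        err := translate(errRaw)
        // Classification runs on the translated error, so one check
        // covers every source error mapped to the canonical sentinel.
        fmt.Println(errors.Is(err, errCanonical)) // true
    }
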
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 1ead3c7e8..2046a48b9 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/syscalls"
- "gvisor.dev/gvisor/pkg/syserror"
)
const (
@@ -124,7 +123,7 @@ var AMD64 = &kernel.SyscallTable{
68: syscalls.Supported("msgget", Msgget),
69: syscalls.Supported("msgsnd", Msgsnd),
70: syscalls.Supported("msgrcv", Msgrcv),
- 71: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}),
+ 71: syscalls.Supported("msgctl", Msgctl),
72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil),
73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil),
@@ -175,8 +174,8 @@ var AMD64 = &kernel.SyscallTable{
119: syscalls.Supported("setresgid", Setresgid),
120: syscalls.Supported("getresgid", Getresgid),
121: syscalls.Supported("getpgid", Getpgid),
- 122: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
- 123: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 122: syscalls.ErrorWithEvent("setfsuid", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 123: syscalls.ErrorWithEvent("setfsgid", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
124: syscalls.Supported("getsid", Getsid),
125: syscalls.Supported("capget", Capget),
126: syscalls.Supported("capset", Capset),
@@ -187,12 +186,12 @@ var AMD64 = &kernel.SyscallTable{
131: syscalls.Supported("sigaltstack", Sigaltstack),
132: syscalls.Supported("utime", Utime),
133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil),
- 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil),
+ 134: syscalls.Error("uselib", linuxerr.ENOSYS, "Obsolete", nil),
135: syscalls.ErrorWithEvent("personality", linuxerr.EINVAL, "Unable to change personality.", nil),
- 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil),
+ 136: syscalls.ErrorWithEvent("ustat", linuxerr.ENOSYS, "Needs filesystem support.", nil),
137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil),
138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil),
- 139: syscalls.ErrorWithEvent("sysfs", syserror.ENOSYS, "", []string{"gvisor.dev/issue/165"}),
+ 139: syscalls.ErrorWithEvent("sysfs", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/165"}),
140: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil),
141: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil),
142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil),
@@ -230,15 +229,15 @@ var AMD64 = &kernel.SyscallTable{
174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil),
175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil),
176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil),
- 177: syscalls.Error("get_kernel_syms", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil),
- 178: syscalls.Error("query_module", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil),
+ 177: syscalls.Error("get_kernel_syms", linuxerr.ENOSYS, "Not supported in Linux > 2.6.", nil),
+ 178: syscalls.Error("query_module", linuxerr.ENOSYS, "Not supported in Linux > 2.6.", nil),
179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations
- 180: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil),
- 181: syscalls.Error("getpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 182: syscalls.Error("putpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 183: syscalls.Error("afs_syscall", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 180: syscalls.Error("nfsservctl", linuxerr.ENOSYS, "Removed after Linux 3.1.", nil),
+ 181: syscalls.Error("getpmsg", linuxerr.ENOSYS, "Not implemented in Linux.", nil),
+ 182: syscalls.Error("putpmsg", linuxerr.ENOSYS, "Not implemented in Linux.", nil),
+ 183: syscalls.Error("afs_syscall", linuxerr.ENOSYS, "Not implemented in Linux.", nil),
+ 184: syscalls.Error("tuxcall", linuxerr.ENOSYS, "Not implemented in Linux.", nil),
+ 185: syscalls.Error("security", linuxerr.ENOSYS, "Not implemented in Linux.", nil),
186: syscalls.Supported("gettid", Gettid),
187: syscalls.Supported("readahead", Readahead),
188: syscalls.PartiallySupported("setxattr", SetXattr, "Only supported for tmpfs.", nil),
@@ -258,18 +257,18 @@ var AMD64 = &kernel.SyscallTable{
202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
203: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil),
204: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil),
- 205: syscalls.Error("set_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
+ 205: syscalls.Error("set_thread_area", linuxerr.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
206: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
207: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
208: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
209: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
210: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 211: syscalls.Error("get_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
+ 211: syscalls.Error("get_thread_area", linuxerr.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil),
213: syscalls.Supported("epoll_create", EpollCreate),
- 214: syscalls.ErrorWithEvent("epoll_ctl_old", syserror.ENOSYS, "Deprecated.", nil),
- 215: syscalls.ErrorWithEvent("epoll_wait_old", syserror.ENOSYS, "Deprecated.", nil),
- 216: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil),
+ 214: syscalls.ErrorWithEvent("epoll_ctl_old", linuxerr.ENOSYS, "Deprecated.", nil),
+ 215: syscalls.ErrorWithEvent("epoll_wait_old", linuxerr.ENOSYS, "Deprecated.", nil),
+ 216: syscalls.ErrorWithEvent("remap_file_pages", linuxerr.ENOSYS, "Deprecated since Linux 3.16.", nil),
217: syscalls.Supported("getdents64", Getdents64),
218: syscalls.Supported("set_tid_address", SetTidAddress),
219: syscalls.Supported("restart_syscall", RestartSyscall),
@@ -289,16 +288,16 @@ var AMD64 = &kernel.SyscallTable{
233: syscalls.Supported("epoll_ctl", EpollCtl),
234: syscalls.Supported("tgkill", Tgkill),
235: syscalls.Supported("utimes", Utimes),
- 236: syscalls.Error("vserver", syserror.ENOSYS, "Not implemented by Linux", nil),
+ 236: syscalls.Error("vserver", linuxerr.ENOSYS, "Not implemented by Linux", nil),
237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}),
238: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil),
239: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil),
- 240: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 241: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 242: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 243: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 244: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 245: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 240: syscalls.ErrorWithEvent("mq_open", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 241: syscalls.ErrorWithEvent("mq_unlink", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 242: syscalls.ErrorWithEvent("mq_timedsend", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 243: syscalls.ErrorWithEvent("mq_timedreceive", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 244: syscalls.ErrorWithEvent("mq_notify", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 245: syscalls.ErrorWithEvent("mq_getsetattr", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil),
247: syscalls.Supported("waitid", Waitid),
248: syscalls.Error("add_key", linuxerr.EACCES, "Not available to user.", nil),
@@ -331,7 +330,7 @@ var AMD64 = &kernel.SyscallTable{
275: syscalls.Supported("splice", Splice),
276: syscalls.Supported("tee", Tee),
277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
- 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 278: syscalls.ErrorWithEvent("vmsplice", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly)
280: syscalls.Supported("utimensat", Utimensat),
281: syscalls.Supported("epoll_pwait", EpollPwait),
@@ -353,8 +352,8 @@ var AMD64 = &kernel.SyscallTable{
297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo),
298: syscalls.ErrorWithEvent("perf_event_open", linuxerr.ENODEV, "No support for perf counters", nil),
299: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil),
- 300: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
- 301: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 300: syscalls.ErrorWithEvent("fanotify_init", linuxerr.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 301: syscalls.ErrorWithEvent("fanotify_mark", linuxerr.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
302: syscalls.Supported("prlimit64", Prlimit64),
303: syscalls.Error("name_to_handle_at", linuxerr.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
304: syscalls.Error("open_by_handle_at", linuxerr.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
@@ -363,48 +362,48 @@ var AMD64 = &kernel.SyscallTable{
307: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil),
308: syscalls.ErrorWithEvent("setns", linuxerr.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995)
309: syscalls.Supported("getcpu", Getcpu),
- 310: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
- 311: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 310: syscalls.ErrorWithEvent("process_vm_readv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 311: syscalls.ErrorWithEvent("process_vm_writev", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil),
313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil),
- 314: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
- 315: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
- 316: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
+ 314: syscalls.ErrorWithEvent("sched_setattr", linuxerr.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 315: syscalls.ErrorWithEvent("sched_getattr", linuxerr.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 316: syscalls.ErrorWithEvent("renameat2", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
317: syscalls.Supported("seccomp", Seccomp),
318: syscalls.Supported("getrandom", GetRandom),
319: syscalls.Supported("memfd_create", MemfdCreate),
320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil),
321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
322: syscalls.Supported("execveat", Execveat),
- 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
+ 323: syscalls.ErrorWithEvent("userfaultfd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
324: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil),
325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
// Syscalls implemented after 325 are "backports" from versions
// of Linux after 4.4.
- 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
+ 326: syscalls.ErrorWithEvent("copy_file_range", linuxerr.ENOSYS, "", nil),
327: syscalls.Supported("preadv2", Preadv2),
328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
- 329: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil),
- 330: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil),
- 331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
+ 329: syscalls.ErrorWithEvent("pkey_mprotect", linuxerr.ENOSYS, "", nil),
+ 330: syscalls.ErrorWithEvent("pkey_alloc", linuxerr.ENOSYS, "", nil),
+ 331: syscalls.ErrorWithEvent("pkey_free", linuxerr.ENOSYS, "", nil),
332: syscalls.Supported("statx", Statx),
- 333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
+ 333: syscalls.ErrorWithEvent("io_pgetevents", linuxerr.ENOSYS, "", nil),
334: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil),
// Linux skips ahead to syscall 424 to sync numbers between arches.
- 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
- 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil),
- 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil),
- 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil),
- 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil),
- 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil),
- 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil),
- 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil),
- 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil),
- 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
- 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
- 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
+ 424: syscalls.ErrorWithEvent("pidfd_send_signal", linuxerr.ENOSYS, "", nil),
+ 425: syscalls.ErrorWithEvent("io_uring_setup", linuxerr.ENOSYS, "", nil),
+ 426: syscalls.ErrorWithEvent("io_uring_enter", linuxerr.ENOSYS, "", nil),
+ 427: syscalls.ErrorWithEvent("io_uring_register", linuxerr.ENOSYS, "", nil),
+ 428: syscalls.ErrorWithEvent("open_tree", linuxerr.ENOSYS, "", nil),
+ 429: syscalls.ErrorWithEvent("move_mount", linuxerr.ENOSYS, "", nil),
+ 430: syscalls.ErrorWithEvent("fsopen", linuxerr.ENOSYS, "", nil),
+ 431: syscalls.ErrorWithEvent("fsconfig", linuxerr.ENOSYS, "", nil),
+ 432: syscalls.ErrorWithEvent("fsmount", linuxerr.ENOSYS, "", nil),
+ 433: syscalls.ErrorWithEvent("fspick", linuxerr.ENOSYS, "", nil),
+ 434: syscalls.ErrorWithEvent("pidfd_open", linuxerr.ENOSYS, "", nil),
+ 435: syscalls.ErrorWithEvent("clone3", linuxerr.ENOSYS, "", nil),
441: syscalls.Supported("epoll_pwait2", EpollPwait2),
},
Emulate: map[hostarch.Addr]uintptr{
@@ -414,7 +413,7 @@ var AMD64 = &kernel.SyscallTable{
},
Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
t.Kernel().EmitUnimplementedEvent(t)
- return 0, syserror.ENOSYS
+ return 0, linuxerr.ENOSYS
},
}
@@ -472,7 +471,7 @@ var ARM64 = &kernel.SyscallTable{
39: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil),
40: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil),
41: syscalls.Error("pivot_root", linuxerr.EPERM, "", nil),
- 42: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil),
+ 42: syscalls.Error("nfsservctl", linuxerr.ENOSYS, "Removed after Linux 3.1.", nil),
43: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil),
44: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil),
45: syscalls.Supported("truncate", Truncate),
@@ -505,7 +504,7 @@ var ARM64 = &kernel.SyscallTable{
72: syscalls.Supported("pselect", Pselect),
73: syscalls.Supported("ppoll", Ppoll),
74: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
- 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 75: syscalls.ErrorWithEvent("vmsplice", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
76: syscalls.Supported("splice", Splice),
77: syscalls.Supported("tee", Tee),
78: syscalls.Supported("readlinkat", Readlinkat),
@@ -581,8 +580,8 @@ var ARM64 = &kernel.SyscallTable{
148: syscalls.Supported("getresuid", Getresuid),
149: syscalls.Supported("setresgid", Setresgid),
150: syscalls.Supported("getresgid", Getresgid),
- 151: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
- 152: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 151: syscalls.ErrorWithEvent("setfsuid", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 152: syscalls.ErrorWithEvent("setfsgid", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
153: syscalls.Supported("times", Times),
154: syscalls.Supported("setpgid", Setpgid),
155: syscalls.Supported("getpgid", Getpgid),
@@ -610,14 +609,14 @@ var ARM64 = &kernel.SyscallTable{
177: syscalls.Supported("getegid", Getegid),
178: syscalls.Supported("gettid", Gettid),
179: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil),
- 180: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 181: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 182: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 183: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 184: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 180: syscalls.ErrorWithEvent("mq_open", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 181: syscalls.ErrorWithEvent("mq_unlink", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 182: syscalls.ErrorWithEvent("mq_timedsend", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 183: syscalls.ErrorWithEvent("mq_timedreceive", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 184: syscalls.ErrorWithEvent("mq_notify", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 185: syscalls.ErrorWithEvent("mq_getsetattr", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
186: syscalls.Supported("msgget", Msgget),
- 187: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}),
+ 187: syscalls.Supported("msgctl", Msgctl),
188: syscalls.Supported("msgrcv", Msgrcv),
189: syscalls.Supported("msgsnd", Msgsnd),
190: syscalls.Supported("semget", Semget),
@@ -664,7 +663,7 @@ var ARM64 = &kernel.SyscallTable{
231: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
232: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil),
233: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil),
- 234: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil),
+ 234: syscalls.ErrorWithEvent("remap_file_pages", linuxerr.ENOSYS, "Deprecated since Linux 3.16.", nil),
235: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}),
236: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil),
237: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil),
@@ -676,60 +675,60 @@ var ARM64 = &kernel.SyscallTable{
243: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil),
260: syscalls.Supported("wait4", Wait4),
261: syscalls.Supported("prlimit64", Prlimit64),
- 262: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
- 263: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 262: syscalls.ErrorWithEvent("fanotify_init", linuxerr.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 263: syscalls.ErrorWithEvent("fanotify_mark", linuxerr.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
264: syscalls.Error("name_to_handle_at", linuxerr.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
265: syscalls.Error("open_by_handle_at", linuxerr.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
266: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil),
267: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil),
268: syscalls.ErrorWithEvent("setns", linuxerr.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995)
269: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil),
- 270: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
- 271: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 270: syscalls.ErrorWithEvent("process_vm_readv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 271: syscalls.ErrorWithEvent("process_vm_writev", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
272: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil),
273: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil),
- 274: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
- 275: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
- 276: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
+ 274: syscalls.ErrorWithEvent("sched_setattr", linuxerr.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 275: syscalls.ErrorWithEvent("sched_getattr", linuxerr.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 276: syscalls.ErrorWithEvent("renameat2", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
277: syscalls.Supported("seccomp", Seccomp),
278: syscalls.Supported("getrandom", GetRandom),
279: syscalls.Supported("memfd_create", MemfdCreate),
280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
281: syscalls.Supported("execveat", Execveat),
- 282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
+ 282: syscalls.ErrorWithEvent("userfaultfd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
283: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil),
284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
// Syscalls after 284 are "backports" from versions of Linux after 4.4.
- 285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
+ 285: syscalls.ErrorWithEvent("copy_file_range", linuxerr.ENOSYS, "", nil),
286: syscalls.Supported("preadv2", Preadv2),
287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
- 288: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil),
- 289: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil),
- 290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
+ 288: syscalls.ErrorWithEvent("pkey_mprotect", linuxerr.ENOSYS, "", nil),
+ 289: syscalls.ErrorWithEvent("pkey_alloc", linuxerr.ENOSYS, "", nil),
+ 290: syscalls.ErrorWithEvent("pkey_free", linuxerr.ENOSYS, "", nil),
291: syscalls.Supported("statx", Statx),
- 292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
+ 292: syscalls.ErrorWithEvent("io_pgetevents", linuxerr.ENOSYS, "", nil),
293: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil),
// Linux skips ahead to syscall 424 to sync numbers between arches.
- 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
- 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil),
- 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil),
- 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil),
- 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil),
- 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil),
- 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil),
- 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil),
- 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil),
- 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
- 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
- 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
+ 424: syscalls.ErrorWithEvent("pidfd_send_signal", linuxerr.ENOSYS, "", nil),
+ 425: syscalls.ErrorWithEvent("io_uring_setup", linuxerr.ENOSYS, "", nil),
+ 426: syscalls.ErrorWithEvent("io_uring_enter", linuxerr.ENOSYS, "", nil),
+ 427: syscalls.ErrorWithEvent("io_uring_register", linuxerr.ENOSYS, "", nil),
+ 428: syscalls.ErrorWithEvent("open_tree", linuxerr.ENOSYS, "", nil),
+ 429: syscalls.ErrorWithEvent("move_mount", linuxerr.ENOSYS, "", nil),
+ 430: syscalls.ErrorWithEvent("fsopen", linuxerr.ENOSYS, "", nil),
+ 431: syscalls.ErrorWithEvent("fsconfig", linuxerr.ENOSYS, "", nil),
+ 432: syscalls.ErrorWithEvent("fsmount", linuxerr.ENOSYS, "", nil),
+ 433: syscalls.ErrorWithEvent("fspick", linuxerr.ENOSYS, "", nil),
+ 434: syscalls.ErrorWithEvent("pidfd_open", linuxerr.ENOSYS, "", nil),
+ 435: syscalls.ErrorWithEvent("clone3", linuxerr.ENOSYS, "", nil),
441: syscalls.Supported("epoll_pwait2", EpollPwait2),
},
Emulate: map[hostarch.Addr]uintptr{},
Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
t.Kernel().EmitUnimplementedEvent(t)
- return 0, syserror.ENOSYS
+ return 0, linuxerr.ENOSYS
},
}
diff --git a/pkg/sentry/syscalls/linux/sigset.go b/pkg/sentry/syscalls/linux/sigset.go
index 9dea78085..373948991 100644
--- a/pkg/sentry/syscalls/linux/sigset.go
+++ b/pkg/sentry/syscalls/linux/sigset.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
// CopyInSigSet copies in a sigset_t, checks its size, and ensures that KILL and
@@ -67,6 +66,6 @@ func copyInSigSetWithSize(t *kernel.Task, addr hostarch.Addr) (hostarch.Addr, ui
maskSize := uint(hostarch.ByteOrder.Uint64(in[8:]))
return maskAddr, maskSize, nil
default:
- return 0, 0, syserror.ENOSYS
+ return 0, 0, linuxerr.ENOSYS
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index 4ce3430e2..2f00c3783 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -26,7 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/eventfd"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -138,7 +138,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
if count > 0 || linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
return uintptr(count), nil, nil
}
- return 0, nil, syserror.ConvertIntr(err, syserror.EINTR)
+ return 0, nil, syserr.ConvertIntr(err, linuxerr.EINTR)
}
}
@@ -216,7 +216,7 @@ func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error)
// It is not presently supported (ENOSYS indicates no support on this
// architecture).
func IoCancel(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
}
// LINT.IfChange
@@ -355,7 +355,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
cbAddr = hostarch.Addr(cbAddrP)
default:
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
}
// Copy in this callback.
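
Based on its call sites in this change, syserr.ConvertIntr(err, intr) presumably returns intr when err is the sentry's internal interruption sentinel and passes err through otherwise; only the errno to substitute (EINTR vs. ERESTARTSYS) varies per syscall. A minimal standalone sketch of that shape, with stand-in errors:

    package main

    import (
        "errors"
        "fmt"
    )

    var errInterrupted = errors.New("interrupted") // stand-in for the sentry's sentinel

    // convertIntr returns intr if err is the interruption sentinel,
    // otherwise it passes err through unchanged.
    func convertIntr(err, intr error) error {
        if errors.Is(err, errInterrupted) {
            return intr
        }
        return err
    }

    func main() {
        eintr := errors.New("EINTR") // stand-in for linuxerr.EINTR
        fmt.Println(convertIntr(errInterrupted, eintr)) // EINTR
        fmt.Println(convertIntr(nil, eintr))            // <nil>
    }
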
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
index daa151bb4..6c807124c 100644
--- a/pkg/sentry/syscalls/linux/sys_epoll.go
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -22,7 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/epoll"
"gvisor.dev/gvisor/pkg/sentry/syscalls"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -109,7 +109,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
func waitEpoll(t *kernel.Task, fd int32, eventsAddr hostarch.Addr, max int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) {
r, err := syscalls.WaitEpoll(t, fd, max, timeoutInNanos)
if err != nil {
- return 0, nil, syserror.ConvertIntr(err, syserror.EINTR)
+ return 0, nil, syserr.ConvertIntr(err, linuxerr.EINTR)
}
if len(r) != 0 {
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 3528d325f..e79b92fb6 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -30,7 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/fasync"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/syserr"
)
// fileOpAt performs an operation on the second last component in the path.
@@ -122,7 +122,7 @@ func copyInPath(t *kernel.Task, addr hostarch.Addr, allowEmpty bool) (path strin
return "", false, err
}
if path == "" && !allowEmpty {
- return "", false, syserror.ENOENT
+ return "", false, linuxerr.ENOENT
}
// If the path ends with a /, then checks must be enforced in various
@@ -162,7 +162,7 @@ func openAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint) (fd uin
if fs.IsDir(d.Inode.StableAttr) {
// Don't allow directories to be opened writable.
if fileFlags.Write {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
} else {
// If O_DIRECTORY is set, but the file is not a directory, then fail.
@@ -177,7 +177,7 @@ func openAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint) (fd uin
file, err := d.Inode.GetFile(t, d, fileFlags)
if err != nil {
- return syserror.ConvertIntr(err, syserror.ERESTARTSYS)
+ return syserr.ConvertIntr(err, linuxerr.ERESTARTSYS)
}
defer file.DecRef(t)
@@ -215,7 +215,7 @@ func mknodAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, mode linux.FileMod
return err
}
if dirPath {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error {
@@ -308,7 +308,7 @@ func createAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint, mode
return 0, err
}
if dirPath {
- return 0, syserror.ENOENT
+ return 0, linuxerr.ENOENT
}
fileFlags := linuxToFlags(flags)
@@ -416,7 +416,7 @@ func createAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, flags uint, mode
// Create a new fs.File.
newFile, err = found.Inode.GetFile(t, found, fileFlags)
if err != nil {
- return syserror.ConvertIntr(err, syserror.ERESTARTSYS)
+ return syserr.ConvertIntr(err, linuxerr.ERESTARTSYS)
}
defer newFile.DecRef(t)
case linuxerr.Equals(linuxerr.ENOENT, err):
@@ -795,7 +795,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
defer file.DecRef(t)
err := file.Flush(t)
- return 0, nil, handleIOError(t, false /* partial */, err, syserror.EINTR, "close", file)
+ return 0, nil, handleIOError(t, false /* partial */, err, linuxerr.EINTR, "close", file)
}
// Dup implements linux syscall dup(2).
@@ -1020,7 +1020,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
} else {
// Blocking lock, pass in the task to satisfy the lock.Blocker interface.
if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.ReadLock, rng, t) {
- return 0, nil, syserror.EINTR
+ return 0, nil, linuxerr.EINTR
}
}
return 0, nil, nil
@@ -1036,7 +1036,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
} else {
// Blocking lock, pass in the task to satisfy the lock.Blocker interface.
if !file.Dirent.Inode.LockCtx.Posix.LockRegionVFS1(t.FDTable(), lock.WriteLock, rng, t) {
- return 0, nil, syserror.EINTR
+ return 0, nil, linuxerr.EINTR
}
}
return 0, nil, nil
@@ -1263,7 +1263,7 @@ func symlinkAt(t *kernel.Task, dirFD int32, newAddr hostarch.Addr, oldAddr hosta
return err
}
if dirPath {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
// The oldPath is copied in verbatim. This is because the symlink
@@ -1273,7 +1273,7 @@ func symlinkAt(t *kernel.Task, dirFD int32, newAddr hostarch.Addr, oldAddr hosta
return err
}
if oldPath == "" {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
return fileOpAt(t, dirFD, newPath, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error {
@@ -1352,7 +1352,7 @@ func linkAt(t *kernel.Task, oldDirFD int32, oldAddr hostarch.Addr, newDirFD int3
return err
}
if dirPath {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if allowEmpty && oldPath == "" {
@@ -1439,7 +1439,7 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
allowEmpty := flags&linux.AT_EMPTY_PATH == linux.AT_EMPTY_PATH
if allowEmpty && !t.HasCapabilityIn(linux.CAP_DAC_READ_SEARCH, t.UserNamespace().Root()) {
- return 0, nil, syserror.ENOENT
+ return 0, nil, linuxerr.ENOENT
}
return 0, nil, linkAt(t, oldDirFD, oldAddr, newDirFD, newAddr, resolve, allowEmpty)
@@ -1455,7 +1455,7 @@ func readlinkAt(t *kernel.Task, dirFD int32, addr hostarch.Addr, bufAddr hostarc
return 0, err
}
if dirPath {
- return 0, syserror.ENOENT
+ return 0, linuxerr.ENOENT
}
err = fileOpOn(t, dirFD, path, false /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
@@ -1579,7 +1579,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
if fs.IsDir(d.Inode.StableAttr) {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
// In contrast to open(O_TRUNC), truncate(2) is only valid for file
// types.
@@ -2131,7 +2131,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, linuxerr.ESPIPE
}
if fs.IsDir(file.Dirent.Inode.StableAttr) {
- return 0, nil, syserror.EISDIR
+ return 0, nil, linuxerr.EISDIR
}
if !fs.IsRegular(file.Dirent.Inode.StableAttr) {
return 0, nil, linuxerr.ENODEV
@@ -2189,7 +2189,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
} else {
// Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.WriteLock, rng, t) {
- return 0, nil, syserror.EINTR
+ return 0, nil, linuxerr.EINTR
}
}
case linux.LOCK_SH:
@@ -2201,7 +2201,7 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
} else {
// Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
if !file.Dirent.Inode.LockCtx.BSD.LockRegionVFS1(file, lock.ReadLock, rng, t) {
- return 0, nil, syserror.EINTR
+ return 0, nil, linuxerr.EINTR
}
}
case linux.LOCK_UN:
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
index 717cec04d..bcdd7b633 100644
--- a/pkg/sentry/syscalls/linux/sys_futex.go
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -23,7 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/syserr"
)
// futexWaitRestartBlock encapsulates the state required to restart futex(2)
@@ -75,7 +75,7 @@ func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, fo
}
t.Futex().WaitComplete(w, t)
- return 0, syserror.ConvertIntr(err, syserror.ERESTARTSYS)
+ return 0, syserr.ConvertIntr(err, linuxerr.ERESTARTSYS)
}
// futexWaitDuration performs a FUTEX_WAIT, blocking until the wait is
@@ -103,7 +103,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add
// The wait was unsuccessful for some reason other than interruption. Simply
// forward the error.
- if err != syserror.ErrInterrupted {
+ if err != linuxerr.ErrInterrupted {
return 0, err
}
@@ -111,7 +111,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add
// The wait duration was absolute, restart with the original arguments.
if forever {
- return 0, syserror.ERESTARTSYS
+ return 0, linuxerr.ERESTARTSYS
}
// The wait duration was relative, restart with the remaining duration.
@@ -122,7 +122,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add
val: val,
mask: mask,
})
- return 0, syserror.ERESTART_RESTARTBLOCK
+ return 0, linuxerr.ERESTART_RESTARTBLOCK
}
func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr hostarch.Addr, private bool) error {
@@ -150,7 +150,7 @@ func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr hostarch.
}
t.Futex().WaitComplete(w, t)
- return syserror.ConvertIntr(err, syserror.ERESTARTSYS)
+ return syserr.ConvertIntr(err, linuxerr.ERESTARTSYS)
}
func tryLockPI(t *kernel.Task, addr hostarch.Addr, private bool) error {
@@ -280,11 +280,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.FUTEX_WAIT_REQUEUE_PI, linux.FUTEX_CMP_REQUEUE_PI:
t.Kernel().EmitUnimplementedEvent(t)
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
default:
// We don't even know about this command.
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
index 917717e31..9f7a5ae8a 100644
--- a/pkg/sentry/syscalls/linux/sys_getdents.go
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -83,7 +82,7 @@ func getdents(t *kernel.Task, fd int32, addr hostarch.Addr, size int, f func(*di
ds := newDirentSerializer(f, w, t.Arch(), size)
rerr := dir.Readdir(t, ds)
- switch err := handleIOError(t, ds.Written() > 0, rerr, syserror.ERESTARTSYS, "getdents", dir); err {
+ switch err := handleIOError(t, ds.Written() > 0, rerr, linuxerr.ERESTARTSYS, "getdents", dir); err {
case nil:
dir.Dirent.InotifyEvent(linux.IN_ACCESS, 0)
return uintptr(ds.Written()), nil
diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go
index bf71a9af3..4a5712a29 100644
--- a/pkg/sentry/syscalls/linux/sys_lseek.go
+++ b/pkg/sentry/syscalls/linux/sys_lseek.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
// LINT.IfChange
@@ -49,7 +48,7 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
offset, serr := file.Seek(t, sw, offset)
- err := handleIOError(t, false /* partialResult */, serr, syserror.ERESTARTSYS, "lseek", file)
+ err := handleIOError(t, false /* partialResult */, serr, linuxerr.ERESTARTSYS, "lseek", file)
if err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index cee621791..7efd17d40 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -24,7 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/syserr"
)
// Brk implements linux syscall brk(2).
@@ -211,7 +211,7 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
case linux.MADV_REMOVE:
// These "suggestions" have application-visible side effects, so we
// have to indicate that we don't support them.
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
case linux.MADV_HWPOISON:
// Only privileged processes are allowed to poison pages.
return 0, nil, linuxerr.EPERM
@@ -235,18 +235,18 @@ func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// rounded up to the next multiple of the page size." - mincore(2)
la, ok := hostarch.Addr(length).RoundUp()
if !ok {
- return 0, nil, syserror.ENOMEM
+ return 0, nil, linuxerr.ENOMEM
}
ar, ok := addr.ToRange(uint64(la))
if !ok {
- return 0, nil, syserror.ENOMEM
+ return 0, nil, linuxerr.ENOMEM
}
// Pretend that all mapped pages are "resident in core".
mapped := t.MemoryManager().VirtualMemorySizeRange(ar)
// "ENOMEM: addr to addr + length contained unmapped memory."
if mapped != uint64(la) {
- return 0, nil, syserror.ENOMEM
+ return 0, nil, linuxerr.ENOMEM
}
resident := bytes.Repeat([]byte{1}, int(mapped/hostarch.PageSize))
_, err := t.CopyOutBytes(vec, resident)
@@ -277,7 +277,7 @@ func Msync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
})
// MSync calls fsync, the same interrupt conversion rules apply, see
// mm/msync.c, fsync POSIX.1-2008.
- return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS)
+ return 0, nil, syserr.ConvertIntr(err, linuxerr.ERESTARTSYS)
}
// Mlock implements linux syscall mlock(2).
diff --git a/pkg/sentry/syscalls/linux/sys_msgqueue.go b/pkg/sentry/syscalls/linux/sys_msgqueue.go
index 5259ade90..60b989ee7 100644
--- a/pkg/sentry/syscalls/linux/sys_msgqueue.go
+++ b/pkg/sentry/syscalls/linux/sys_msgqueue.go
@@ -130,12 +130,63 @@ func receive(t *kernel.Task, id ipc.ID, mType int64, maxSize int64, msgCopy, wai
func Msgctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
id := ipc.ID(args[0].Int())
cmd := args[1].Int()
+ buf := args[2].Pointer()
creds := auth.CredentialsFromContext(t)
+ r := t.IPCNamespace().MsgqueueRegistry()
+
switch cmd {
+ case linux.IPC_INFO:
+ info := r.IPCInfo(t)
+ _, err := info.CopyOut(t, buf)
+ return 0, nil, err
+ case linux.MSG_INFO:
+ msgInfo := r.MsgInfo(t)
+ _, err := msgInfo.CopyOut(t, buf)
+ return 0, nil, err
case linux.IPC_RMID:
- return 0, nil, t.IPCNamespace().MsgqueueRegistry().Remove(id, creds)
+ return 0, nil, r.Remove(id, creds)
+ }
+
+ // Remaining commands use a queue.
+ queue, err := r.FindByID(id)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ switch cmd {
+ case linux.MSG_STAT:
+ // Technically, we should be treating id as "an index into the kernel's
+ // internal array that maintains information about all message
+ // queues on the system". Since we don't track queues in an array,
+ // we'll just pretend the msqid is the index and do the same thing as
+ // IPC_STAT. Linux also uses the index as the msqid.
+ fallthrough
+ case linux.IPC_STAT:
+ stat, err := queue.Stat(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ _, err = stat.CopyOut(t, buf)
+ return 0, nil, err
+
+ case linux.MSG_STAT_ANY:
+ stat, err := queue.StatAny(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ _, err = stat.CopyOut(t, buf)
+ return 0, nil, err
+
+ case linux.IPC_SET:
+ var ds linux.MsqidDS
+ if _, err := ds.CopyIn(t, buf); err != nil {
+ return 0, nil, linuxerr.EINVAL
+ }
+ err := queue.Set(t, &ds)
+ return 0, nil, err
+
default:
return 0, nil, linuxerr.EINVAL
}
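
The new Msgctl body dispatches in two phases: registry-wide commands (IPC_INFO, MSG_INFO, IPC_RMID) run before any lookup, and the remaining commands first resolve a queue via FindByID. MSG_STAT_ANY then differs from MSG_STAT/IPC_STAT only in permissions. A hedged sketch of that split, with Queue standing in for the msgqueue package's internal type and checkRead for its permission hook (both hypothetical names):

    package sketch

    import (
    	"gvisor.dev/gvisor/pkg/abi/linux"
    	"gvisor.dev/gvisor/pkg/errors/linuxerr"
    	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
    )

    // Queue is a stand-in for the registry's queue type.
    type Queue struct {
    	ds        linux.MsqidDS
    	checkRead func(creds *auth.Credentials) bool // hypothetical hook
    }

    // Stat requires read permission on the queue, as IPC_STAT and
    // MSG_STAT do in Linux.
    func (q *Queue) Stat(creds *auth.Credentials) (*linux.MsqidDS, error) {
    	if !q.checkRead(creds) {
    		return nil, linuxerr.EACCES
    	}
    	return q.snapshot(), nil
    }

    // StatAny implements MSG_STAT_ANY, which (as in Linux) skips the
    // read-permission check.
    func (q *Queue) StatAny(creds *auth.Credentials) (*linux.MsqidDS, error) {
    	return q.snapshot(), nil
    }

    func (q *Queue) snapshot() *linux.MsqidDS {
    	ds := q.ds
    	return &ds
    }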
diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go
index a80c84fcd..ee4dbbc64 100644
--- a/pkg/sentry/syscalls/linux/sys_poll.go
+++ b/pkg/sentry/syscalls/linux/sys_poll.go
@@ -25,7 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -185,7 +185,7 @@ func doPoll(t *kernel.Task, addr hostarch.Addr, nfds uint, timeout time.Duration
pfd[i].Events |= linux.POLLHUP | linux.POLLERR
}
remainingTimeout, n, err := pollBlock(t, pfd, timeout)
- err = syserror.ConvertIntr(err, syserror.EINTR)
+ err = syserr.ConvertIntr(err, linuxerr.EINTR)
// The poll entries are copied out regardless of whether
// any are set or not. This aligns with the Linux behavior.
@@ -295,7 +295,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Ad
// Do the syscall, then count the number of bits set.
if _, _, err = pollBlock(t, pfd, timeout); err != nil {
- return 0, syserror.ConvertIntr(err, syserror.EINTR)
+ return 0, syserr.ConvertIntr(err, linuxerr.EINTR)
}
// r, w, and e are currently event mask bitsets; unset bits corresponding
@@ -411,7 +411,7 @@ func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duratio
nfds: nfds,
timeout: remainingTimeout,
})
- return 0, syserror.ERESTART_RESTARTBLOCK
+ return 0, linuxerr.ERESTART_RESTARTBLOCK
}
return n, err
}
@@ -465,7 +465,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Note that this means that if err is nil but copyErr is not, copyErr is
// ignored. This is consistent with Linux.
if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
- err = syserror.ERESTARTNOHAND
+ err = linuxerr.ERESTARTNOHAND
}
return n, nil, err
}
@@ -495,7 +495,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
// See comment in Ppoll.
if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
- err = syserror.ERESTARTNOHAND
+ err = linuxerr.ERESTARTNOHAND
}
return n, nil, err
}
@@ -540,7 +540,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
// See comment in Ppoll.
if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
- err = syserror.ERESTARTNOHAND
+ err = linuxerr.ERESTARTNOHAND
}
return n, nil, err
}
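
Three different restart errors appear in this file, and the hunks preserve their distinctions while changing packages: ERESTARTSYS restarts transparently under SA_RESTART, ERESTARTNOHAND restarts only when no handler runs (used by ppoll/pselect after the remaining timeout has been copied back), and ERESTART_RESTARTBLOCK re-enters through a registered restart block so the remaining timeout is honored. The block that poll registers plausibly has this shape; the field names follow the literal in the hunk and the poll signature, and the Restart body is a hedged completion:

    package sketch

    import (
    	"time"

    	"gvisor.dev/gvisor/pkg/hostarch"
    	"gvisor.dev/gvisor/pkg/sentry/kernel"
    )

    // pollRestartBlock resumes an interrupted poll with the timeout that
    // remained when the signal arrived, rather than the original timeout.
    type pollRestartBlock struct {
    	pfdAddr hostarch.Addr
    	nfds    uint
    	timeout time.Duration
    }

    // Restart implements kernel.SyscallRestartBlock.Restart by
    // re-entering the poll helper defined in this file.
    func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
    	return poll(t, p.pfdAddr, p.nfds, p.timeout)
    }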
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index b54a3a11f..18ea23913 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -72,7 +71,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
n, err := readv(t, file, dst)
t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "read", file)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "read", file)
}
// Readahead implements readahead(2).
@@ -152,7 +151,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
n, err := preadv(t, file, dst, offset)
t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pread64", file)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "pread64", file)
}
// Readv implements linux syscall readv(2).
@@ -182,7 +181,7 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
n, err := readv(t, file, dst)
t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "readv", file)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "readv", file)
}
// Preadv implements linux syscall preadv(2).
@@ -223,7 +222,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
n, err := preadv(t, file, dst, offset)
t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "preadv", file)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "preadv", file)
}
// Preadv2 implements linux syscall preadv2(2).
@@ -281,17 +280,17 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if offset == -1 {
n, err := readv(t, file, dst)
t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "preadv2", file)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "preadv2", file)
}
n, err := preadv(t, file, dst, offset)
t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "preadv2", file)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "preadv2", file)
}
func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) {
n, err := f.Readv(t, dst)
- if err != syserror.ErrWouldBlock || f.Flags().NonBlocking {
+ if err != linuxerr.ErrWouldBlock || f.Flags().NonBlocking {
if n > 0 {
// Queue notification if we read anything.
f.Dirent.InotifyEvent(linux.IN_ACCESS, 0)
@@ -304,7 +303,7 @@ func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) {
var deadline ktime.Time
if s, ok := f.FileOperations.(socket.Socket); ok {
dl := s.RecvTimeout()
- if dl < 0 && err == syserror.ErrWouldBlock {
+ if dl < 0 && err == linuxerr.ErrWouldBlock {
return n, err
}
if dl > 0 {
@@ -326,14 +325,14 @@ func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) {
// other than "would block".
n, err = f.Readv(t, dst)
total += n
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
break
}
// Wait for a notification that we should retry.
if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
break
}
@@ -351,7 +350,7 @@ func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) {
func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
n, err := f.Preadv(t, dst, offset)
- if err != syserror.ErrWouldBlock || f.Flags().NonBlocking {
+ if err != linuxerr.ErrWouldBlock || f.Flags().NonBlocking {
if n > 0 {
// Queue notification if we read anything.
f.Dirent.InotifyEvent(linux.IN_ACCESS, 0)
@@ -372,7 +371,7 @@ func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (i
// other than "would block".
n, err = f.Preadv(t, dst, offset+total)
total += n
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
break
}
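
Every ErrWouldBlock comparison these hunks touch sits inside the same retry loop: attempt the I/O, and while it would block, wait for readiness (optionally bounded by a socket send/receive timeout) and try again, translating a wait timeout back into ErrWouldBlock. A condensed, self-contained sketch of that loop (blockingRetry and attempt are illustrative names; the real loops inline this per call site):

    package sketch

    import (
    	"gvisor.dev/gvisor/pkg/errors/linuxerr"
    	"gvisor.dev/gvisor/pkg/sentry/kernel"
    	ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
    	"gvisor.dev/gvisor/pkg/waiter"
    )

    // blockingRetry keeps calling attempt until it returns something
    // other than "would block", blocking on q between attempts.
    func blockingRetry(t *kernel.Task, q waiter.Waitable, mask waiter.EventMask,
    	haveDeadline bool, deadline ktime.Time,
    	attempt func() (int64, error)) (int64, error) {

    	e, ch := waiter.NewChannelEntry(nil)
    	q.EventRegister(&e, mask)
    	defer q.EventUnregister(&e)

    	var total int64
    	for {
    		n, err := attempt()
    		total += n
    		if err != linuxerr.ErrWouldBlock {
    			return total, err
    		}
    		// Wait for a notification that we should retry.
    		if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
    			if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
    				// A socket timeout surfaces as EAGAIN, not
    				// ETIMEDOUT, so restore ErrWouldBlock.
    				err = linuxerr.ErrWouldBlock
    			}
    			return total, err
    		}
    	}
    }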
diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go
index a12e1c915..7210333d2 100644
--- a/pkg/sentry/syscalls/linux/sys_rlimit.go
+++ b/pkg/sentry/syscalls/linux/sys_rlimit.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/limits"
- "gvisor.dev/gvisor/pkg/syserror"
)
// rlimit describes an implementation of 'struct rlimit', which may vary from
@@ -44,7 +43,7 @@ func newRlimit(t *kernel.Task) (rlimit, error) {
// On 64-bit system, struct rlimit and struct rlimit64 are identical.
return &rlimit64{}, nil
default:
- return nil, syserror.ENOSYS
+ return nil, linuxerr.ENOSYS
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_rseq.go b/pkg/sentry/syscalls/linux/sys_rseq.go
index 5fe196647..8328a3742 100644
--- a/pkg/sentry/syscalls/linux/sys_rseq.go
+++ b/pkg/sentry/syscalls/linux/sys_rseq.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
// RSeq implements syscall rseq(2).
@@ -33,7 +32,7 @@ func RSeq(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Event for applications that want rseq on a configuration
// that doesn't support them.
t.Kernel().EmitUnimplementedEvent(t)
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
}
switch flags {
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index f61cc466c..5a119b21c 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
@@ -166,8 +165,7 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, err
}
- perms := fs.FilePermsFromMode(linux.FileMode(s.SemPerm.Mode & 0777))
- return 0, nil, ipcSet(t, id, auth.UID(s.SemPerm.UID), auth.GID(s.SemPerm.GID), perms)
+ return 0, nil, ipcSet(t, id, &s)
case linux.GETPID:
v, err := getPID(t, id, num)
@@ -243,24 +241,13 @@ func remove(t *kernel.Task, id ipc.ID) error {
return r.Remove(id, creds)
}
-func ipcSet(t *kernel.Task, id ipc.ID, uid auth.UID, gid auth.GID, perms fs.FilePermissions) error {
+func ipcSet(t *kernel.Task, id ipc.ID, ds *linux.SemidDS) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
if set == nil {
return linuxerr.EINVAL
}
-
- creds := auth.CredentialsFromContext(t)
- kuid := creds.UserNamespace.MapToKUID(uid)
- if !kuid.Ok() {
- return linuxerr.EINVAL
- }
- kgid := creds.UserNamespace.MapToKGID(gid)
- if !kgid.Ok() {
- return linuxerr.EINVAL
- }
- owner := fs.FileOwner{UID: kuid, GID: kgid}
- return set.Change(t, creds, owner, perms)
+ return set.Set(t, ds)
}
func ipcStat(t *kernel.Task, id ipc.ID) (*linux.SemidDS, error) {
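
The ipcSet rewrite pushes credential mapping down into the semaphore registry: rather than converting SemPerm fields to fs types at the syscall layer, the whole SemidDS is handed to the set's new Set method. A hedged sketch of what that method is expected to do on the semaphore side (Set stands in for the type in pkg/sentry/kernel/semaphore; the mapping mirrors the lines deleted above):

    package sketch

    import (
    	"gvisor.dev/gvisor/pkg/abi/linux"
    	"gvisor.dev/gvisor/pkg/errors/linuxerr"
    	"gvisor.dev/gvisor/pkg/sentry/kernel"
    	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
    )

    // Set stands in for the semaphore set type.
    type Set struct{ /* owner, perms, ctime elided */ }

    // Set applies IPC_SET semantics: map the DS-provided owner through
    // the caller's user namespace, failing with EINVAL if either ID does
    // not map, then update ownership and mode (elided here).
    func (s *Set) Set(t *kernel.Task, ds *linux.SemidDS) error {
    	creds := auth.CredentialsFromContext(t)
    	kuid := creds.UserNamespace.MapToKUID(auth.UID(ds.SemPerm.UID))
    	kgid := creds.UserNamespace.MapToKGID(auth.GID(ds.SemPerm.GID))
    	if !kuid.Ok() || !kgid.Ok() {
    		return linuxerr.EINVAL
    	}
    	// ... update owner, mode bits, and ctime under the set's mutex ...
    	return nil
    }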
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index 45608f3fa..03871d713 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -25,7 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/signalfd"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/syserr"
)
// "For a process to have permission to send a signal it must
@@ -348,7 +348,7 @@ func Sigaltstack(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
// Pause implements linux syscall pause(2).
func Pause(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- return 0, nil, syserror.ConvertIntr(t.Block(nil), syserror.ERESTARTNOHAND)
+ return 0, nil, syserr.ConvertIntr(t.Block(nil), linuxerr.ERESTARTNOHAND)
}
// RtSigpending implements linux syscall rt_sigpending(2).
@@ -496,7 +496,7 @@ func RtSigsuspend(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
t.SetSavedSignalMask(oldmask)
// Perform the wait.
- return 0, nil, syserror.ConvertIntr(t.Block(nil), syserror.ERESTARTNOHAND)
+ return 0, nil, syserr.ConvertIntr(t.Block(nil), linuxerr.ERESTARTNOHAND)
}
// RestartSyscall implements the linux syscall restart_syscall(2).
@@ -512,7 +512,7 @@ func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
// function is never null by (re)initializing it with one that translates
// the restart into EINTR. We'll emulate that behaviour.
t.Debugf("Restart block missing in restart_syscall(2). Did ptrace inject a return value of ERESTART_RESTARTBLOCK?")
- return 0, nil, syserror.EINTR
+ return 0, nil, linuxerr.EINTR
}
// sharedSignalfd is shared between the two calls.
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 06eb8f319..50ddbc142 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -30,7 +30,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -260,7 +259,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Capture address and call syscall implementation.
@@ -270,7 +269,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
blocking := !file.Flags().NonBlocking
- return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), syserror.ERESTARTSYS)
+ return 0, nil, syserr.ConvertIntr(s.Connect(t, a, blocking).ToError(), linuxerr.ERESTARTSYS)
}
// accept is the implementation of the accept syscall. It is called by accept
@@ -291,7 +290,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr,
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, syserror.ENOTSOCK
+ return 0, linuxerr.ENOTSOCK
}
// Call the syscall implementation for this socket, then copy the
@@ -301,7 +300,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr,
peerRequested := addrLen != 0
nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking)
if e != nil {
- return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS)
+ return 0, syserr.ConvertIntr(e.ToError(), linuxerr.ERESTARTSYS)
}
if peerRequested {
// NOTE(magi): Linux does not give you an error if it can't
@@ -350,7 +349,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Capture address and call syscall implementation.
@@ -377,7 +376,7 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
if backlog > maxListenBacklog {
@@ -415,7 +414,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Validate how, then call syscall implementation.
@@ -446,7 +445,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Read the length. Reject negative values.
@@ -527,7 +526,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
if optLen < 0 {
@@ -565,7 +564,7 @@ func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Get the socket name and copy it to the caller.
@@ -593,7 +592,7 @@ func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Get the socket peer name and copy it to the caller.
@@ -626,7 +625,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Reject flags that we don't handle yet.
@@ -683,7 +682,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
if file.Flags().NonBlocking {
@@ -763,7 +762,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr hostarch.Addr, flags
if msg.ControlLen == 0 && msg.NameLen == 0 {
n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
if err != nil {
- return 0, syserror.ConvertIntr(err.ToError(), syserror.ERESTARTSYS)
+ return 0, syserr.ConvertIntr(err.ToError(), linuxerr.ERESTARTSYS)
}
if !cms.Unix.Empty() {
mflags |= linux.MSG_CTRUNC
@@ -785,7 +784,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr hostarch.Addr, flags
}
n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen)
if e != nil {
- return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS)
+ return 0, syserr.ConvertIntr(e.ToError(), linuxerr.ERESTARTSYS)
}
defer cms.Release(t)
@@ -848,7 +847,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, fla
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, syserror.ENOTSOCK
+ return 0, linuxerr.ENOTSOCK
}
if file.Flags().NonBlocking {
@@ -874,7 +873,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, fla
n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
cm.Release(t)
if e != nil {
- return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS)
+ return 0, syserr.ConvertIntr(e.ToError(), linuxerr.ERESTARTSYS)
}
// Copy the address to the caller.
@@ -921,7 +920,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Reject flags that we don't handle yet.
@@ -963,7 +962,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Reject flags that we don't handle yet.
@@ -1060,7 +1059,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr hostar
// Call the syscall implementation.
n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages)
- err = handleIOError(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendmsg", file)
+ err = handleIOError(t, n != 0, e.ToError(), linuxerr.ERESTARTSYS, "sendmsg", file)
// Control messages should be released on error as well as for zero-length
// messages, which are discarded by the receiver.
if n == 0 || err != nil {
@@ -1087,7 +1086,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
if !ok {
- return 0, syserror.ENOTSOCK
+ return 0, linuxerr.ENOTSOCK
}
if file.Flags().NonBlocking {
@@ -1122,7 +1121,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags
// Call the syscall implementation.
n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)})
- return uintptr(n), handleIOError(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendto", file)
+ return uintptr(n), handleIOError(t, n != 0, e.ToError(), linuxerr.ERESTARTSYS, "sendto", file)
}
// SendTo implements the linux syscall sendto(2).
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index 34d87ac1f..8c8847efa 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -46,9 +45,9 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
for {
n, err = fs.Splice(t, outFile, inFile, opts)
- if n != 0 || err != syserror.ErrWouldBlock {
+ if n != 0 || err != linuxerr.ErrWouldBlock {
break
- } else if err == syserror.ErrWouldBlock && nonBlocking {
+ } else if err == linuxerr.ErrWouldBlock && nonBlocking {
break
}
@@ -177,7 +176,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// We can only pass a single file to handleIOError, so pick inFile
// arbitrarily. This is used only for debugging purposes.
- return uintptr(n), nil, handleIOError(t, false, err, syserror.ERESTARTSYS, "sendfile", inFile)
+ return uintptr(n), nil, handleIOError(t, false, err, linuxerr.ERESTARTSYS, "sendfile", inFile)
}
// Splice implements splice(2).
@@ -287,7 +286,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
// See above; inFile is chosen arbitrarily here.
- return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "splice", inFile)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "splice", inFile)
}
// Tee implements tee(2).
@@ -340,5 +339,5 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
}
// See above; inFile is chosen arbitrarily here.
- return uintptr(n), nil, handleIOError(t, false, err, syserror.ERESTARTSYS, "tee", inFile)
+ return uintptr(n), nil, handleIOError(t, false, err, linuxerr.ERESTARTSYS, "tee", inFile)
}
diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go
index 6278bef21..0c22599bf 100644
--- a/pkg/sentry/syscalls/linux/sys_sync.go
+++ b/pkg/sentry/syscalls/linux/sys_sync.go
@@ -20,7 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/syserr"
)
// LINT.IfChange
@@ -58,7 +58,7 @@ func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
defer file.DecRef(t)
err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncAll)
- return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS)
+ return 0, nil, syserr.ConvertIntr(err, linuxerr.ERESTARTSYS)
}
// Fdatasync implements linux syscall fdatasync(2).
@@ -74,7 +74,7 @@ func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
defer file.DecRef(t)
err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncData)
- return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS)
+ return 0, nil, syserr.ConvertIntr(err, linuxerr.ERESTARTSYS)
}
// SyncFileRange implements linux syscall sync_file_range(2).
@@ -112,7 +112,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
if uflags&linux.SYNC_FILE_RANGE_WAIT_BEFORE != 0 &&
uflags&linux.SYNC_FILE_RANGE_WAIT_AFTER == 0 {
t.Kernel().EmitUnimplementedEvent(t)
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
}
// SYNC_FILE_RANGE_WRITE initiates write-out of all dirty pages in the
@@ -137,7 +137,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
err = file.Fsync(t, offset, fs.FileMaxOffset, fs.SyncData)
}
- return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS)
+ return 0, nil, syserr.ConvertIntr(err, linuxerr.ERESTARTSYS)
}
// LINT.ThenChange(vfs2/sync.go)
diff --git a/pkg/sentry/syscalls/linux/sys_syslog.go b/pkg/sentry/syscalls/linux/sys_syslog.go
index ba372f9e3..15acb2b8b 100644
--- a/pkg/sentry/syscalls/linux/sys_syslog.go
+++ b/pkg/sentry/syscalls/linux/sys_syslog.go
@@ -18,7 +18,6 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
const (
@@ -57,6 +56,6 @@ func Syslog(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
case _SYSLOG_ACTION_SIZE_BUFFER:
return logBufLen, nil, nil
default:
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 981cdd985..d74173c56 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -27,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/loader"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -111,7 +110,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr host
}
atEmptyPath := flags&linux.AT_EMPTY_PATH != 0
if !atEmptyPath && len(pathname) == 0 {
- return 0, nil, syserror.ENOENT
+ return 0, nil, linuxerr.ENOENT
}
resolveFinal := flags&linux.AT_SYMLINK_NOFOLLOW == 0
@@ -244,7 +243,7 @@ func parseCommonWaitOptions(wopts *kernel.WaitOptions, options int) error {
wopts.Events |= kernel.EventGroupContinue
}
if options&linux.WNOHANG == 0 {
- wopts.BlockInterruptErr = syserror.ERESTARTSYS
+ wopts.BlockInterruptErr = linuxerr.ERESTARTSYS
}
if options&linux.WNOTHREAD == 0 {
wopts.SiblingChildren = true
diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
index 674e74f82..4adc8b8a4 100644
--- a/pkg/sentry/syscalls/linux/sys_time.go
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/syserror"
)
// The most significant 29 bits hold either a pid or a file descriptor.
@@ -214,7 +213,7 @@ func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, end ktime.Time, rem host
case linuxerr.Equals(linuxerr.ETIMEDOUT, err):
// Slept for entire timeout.
return nil
- case err == syserror.ErrInterrupted:
+ case err == linuxerr.ErrInterrupted:
// Interrupted.
remaining := end.Sub(c.Now())
if remaining <= 0 {
@@ -235,9 +234,9 @@ func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, end ktime.Time, rem host
end: end,
rem: rem,
})
- return syserror.ERESTART_RESTARTBLOCK
+ return linuxerr.ERESTART_RESTARTBLOCK
}
- return syserror.ERESTARTNOHAND
+ return linuxerr.ERESTARTNOHAND
default:
panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err))
}
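
clockNanosleepUntil shows both restart flavors side by side: if the caller supplied a rem pointer (and the flags allow it), the task registers a restart block that resumes sleeping to the same absolute end time and returns ERESTART_RESTARTBLOCK; otherwise it returns ERESTARTNOHAND. The registered block plausibly looks like this; the field names follow the literal in the hunk, and the Restart body (including the final bool argument) is a hedged completion:

    package sketch

    import (
    	"gvisor.dev/gvisor/pkg/hostarch"
    	"gvisor.dev/gvisor/pkg/sentry/kernel"
    	ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
    )

    // clockNanosleepRestartBlock resumes an interrupted sleep at the
    // original absolute deadline, so already-elapsed time is not
    // re-slept.
    type clockNanosleepRestartBlock struct {
    	c   ktime.Clock
    	end ktime.Time
    	rem hostarch.Addr
    }

    // Restart implements kernel.SyscallRestartBlock.Restart, re-entering
    // the sleep helper above; the trailing argument (assumed) permits
    // registering a fresh restart block if interrupted again.
    func (n *clockNanosleepRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
    	return 0, clockNanosleepUntil(t, n.c, n.end, n.rem, true)
    }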
diff --git a/pkg/sentry/syscalls/linux/sys_timer.go b/pkg/sentry/syscalls/linux/sys_timer.go
index 45eef4feb..d39a0a6f5 100644
--- a/pkg/sentry/syscalls/linux/sys_timer.go
+++ b/pkg/sentry/syscalls/linux/sys_timer.go
@@ -18,9 +18,9 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
const nsecPerSec = int64(time.Second)
@@ -29,7 +29,7 @@ const nsecPerSec = int64(time.Second)
func Getitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
if t.Arch().Width() != 8 {
// Definition of linux.ItimerVal assumes 64-bit architecture.
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
}
timerID := args[0].Int()
@@ -51,7 +51,7 @@ func Getitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
func Setitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
if t.Arch().Width() != 8 {
// Definition of linux.ItimerVal assumes 64-bit architecture.
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
}
timerID := args[0].Int()
diff --git a/pkg/sentry/syscalls/linux/sys_tls_amd64.go b/pkg/sentry/syscalls/linux/sys_tls_amd64.go
index 8c6cd7511..bde672d67 100644
--- a/pkg/sentry/syscalls/linux/sys_tls_amd64.go
+++ b/pkg/sentry/syscalls/linux/sys_tls_amd64.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
// ArchPrctl implements linux syscall arch_prctl(2).
@@ -39,7 +38,7 @@ func ArchPrctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, err
}
default:
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
}
case linux.ARCH_SET_FS:
fsbase := args[1].Uint64()
diff --git a/pkg/sentry/syscalls/linux/sys_tls_arm64.go b/pkg/sentry/syscalls/linux/sys_tls_arm64.go
index ff4ac4d6d..dfa684387 100644
--- a/pkg/sentry/syscalls/linux/sys_tls_arm64.go
+++ b/pkg/sentry/syscalls/linux/sys_tls_arm64.go
@@ -18,12 +18,12 @@
package linux
import (
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
// ArchPrctl is not defined for ARM64.
func ArchPrctl(*kernel.Task, arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
}
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index 872168606..4a4ef5046 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -72,7 +71,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
n, err := writev(t, file, src)
t.IOUsage().AccountWriteSyscall(n)
- return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "write", file)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "write", file)
}
// Pwrite64 implements linux syscall pwrite64(2).
@@ -119,7 +118,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
n, err := pwritev(t, file, src, offset)
t.IOUsage().AccountWriteSyscall(n)
- return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwrite64", file)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "pwrite64", file)
}
// Writev implements linux syscall writev(2).
@@ -149,7 +148,7 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
n, err := writev(t, file, src)
t.IOUsage().AccountWriteSyscall(n)
- return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "writev", file)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "writev", file)
}
// Pwritev implements linux syscall pwritev(2).
@@ -190,7 +189,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
n, err := pwritev(t, file, src, offset)
t.IOUsage().AccountWriteSyscall(n)
- return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwritev", file)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "pwritev", file)
}
// Pwritev2 implements linux syscall pwritev2(2).
@@ -251,17 +250,17 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if offset == -1 {
n, err := writev(t, file, src)
t.IOUsage().AccountWriteSyscall(n)
- return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwritev2", file)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "pwritev2", file)
}
n, err := pwritev(t, file, src, offset)
t.IOUsage().AccountWriteSyscall(n)
- return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwritev2", file)
+ return uintptr(n), nil, handleIOError(t, n != 0, err, linuxerr.ERESTARTSYS, "pwritev2", file)
}
func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) {
n, err := f.Writev(t, src)
- if err != syserror.ErrWouldBlock || f.Flags().NonBlocking {
+ if err != linuxerr.ErrWouldBlock || f.Flags().NonBlocking {
if n > 0 {
// Queue notification if we wrote anything.
f.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
@@ -274,7 +273,7 @@ func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) {
var deadline ktime.Time
if s, ok := f.FileOperations.(socket.Socket); ok {
dl := s.SendTimeout()
- if dl < 0 && err == syserror.ErrWouldBlock {
+ if dl < 0 && err == linuxerr.ErrWouldBlock {
return n, err
}
if dl > 0 {
@@ -296,14 +295,14 @@ func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) {
// anything other than "would block".
n, err = f.Writev(t, src)
total += n
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
break
}
// Wait for a notification that we should retry.
if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
break
}
@@ -321,7 +320,7 @@ func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) {
func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
n, err := f.Pwritev(t, src, offset)
- if err != syserror.ErrWouldBlock || f.Flags().NonBlocking {
+ if err != linuxerr.ErrWouldBlock || f.Flags().NonBlocking {
if n > 0 {
// Queue notification if we wrote anything.
f.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
@@ -342,7 +341,7 @@ func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) (
// anything other than "would block".
n, err = f.Pwritev(t, src, offset+total)
total += n
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
break
}
diff --git a/pkg/sentry/syscalls/linux/timespec.go b/pkg/sentry/syscalls/linux/timespec.go
index b327e27d6..d90652a3f 100644
--- a/pkg/sentry/syscalls/linux/timespec.go
+++ b/pkg/sentry/syscalls/linux/timespec.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
// copyTimespecIn copies a Timespec from the untrusted app range to the kernel.
@@ -38,7 +37,7 @@ func copyTimespecIn(t *kernel.Task, addr hostarch.Addr) (linux.Timespec, error)
ts.Nsec = int64(hostarch.ByteOrder.Uint64(in[8:]))
return ts, nil
default:
- return linux.Timespec{}, syserror.ENOSYS
+ return linux.Timespec{}, linuxerr.ENOSYS
}
}
@@ -52,7 +51,7 @@ func copyTimespecOut(t *kernel.Task, addr hostarch.Addr, ts *linux.Timespec) err
_, err := t.CopyOutBytes(addr, out)
return err
default:
- return syserror.ENOSYS
+ return linuxerr.ENOSYS
}
}
@@ -70,7 +69,7 @@ func copyTimevalIn(t *kernel.Task, addr hostarch.Addr) (linux.Timeval, error) {
tv.Usec = int64(hostarch.ByteOrder.Uint64(in[8:]))
return tv, nil
default:
- return linux.Timeval{}, syserror.ENOSYS
+ return linux.Timeval{}, linuxerr.ENOSYS
}
}
@@ -84,7 +83,7 @@ func copyTimevalOut(t *kernel.Task, addr hostarch.Addr, tv *linux.Timeval) error
_, err := t.CopyOutBytes(addr, out)
return err
default:
- return syserror.ENOSYS
+ return linuxerr.ENOSYS
}
}
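
Each helper in this file gates on t.Arch().Width() == 8 because the hand-rolled encodings assume a 64-bit struct timespec/timeval: two int64 words in native byte order. The copy-out path, for instance, reduces to the following condensed restatement of the surrounding code:

    // 64-bit struct timespec: Sec and Nsec as consecutive int64 words,
    // encoded with the host byte order and copied to the app address.
    var out [16]byte
    hostarch.ByteOrder.PutUint64(out[0:], uint64(ts.Sec))
    hostarch.ByteOrder.PutUint64(out[8:], uint64(ts.Nsec))
    _, err := t.CopyOutBytes(addr, out[:])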
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index a73f096ff..1e3bd2a50 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -73,7 +73,6 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go
index a8fa86cdc..0b57c0f7c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/aio.go
+++ b/pkg/sentry/syscalls/linux/vfs2/aio.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/mm"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -56,7 +55,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
cbAddr = hostarch.Addr(cbAddrP)
default:
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
}
// Copy in this callback.
diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go
index 38818c175..fcf2e25de 100644
--- a/pkg/sentry/syscalls/linux/vfs2/execve.go
+++ b/pkg/sentry/syscalls/linux/vfs2/execve.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/loader"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Execve implements linux syscall execve(2).
@@ -83,7 +82,7 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr host
// do_open_execat(fd=AT_FDCWD)), and the loader package is currently
// incapable of handling this correctly.
if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
- return 0, nil, syserror.ENOENT
+ return 0, nil, linuxerr.ENOENT
}
dirfile, dirfileFlags := t.FDTable().GetVFS2(dirfd)
if dirfile == nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index 2cfb12cad..2198aa065 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Close implements Linux syscall close(2).
@@ -42,7 +41,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
defer file.DecRef(t)
err := file.OnClose(t)
- return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file)
+ return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, linuxerr.EINTR, "close", file)
}
// Dup implements Linux syscall dup(2).
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
index 534355237..f19f0fd41 100644
--- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Link implements Linux syscall link(2).
@@ -46,7 +45,7 @@ func linkat(t *kernel.Task, olddirfd int32, oldpathAddr hostarch.Addr, newdirfd
return linuxerr.EINVAL
}
if flags&linux.AT_EMPTY_PATH != 0 && !t.HasCapability(linux.CAP_DAC_READ_SEARCH) {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
oldpath, err := copyInPath(t, oldpathAddr)
@@ -320,7 +319,7 @@ func symlinkat(t *kernel.Task, targetAddr hostarch.Addr, newdirfd int32, linkpat
return err
}
if len(target) == 0 {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
linkpath, err := copyInPath(t, linkpathAddr)
if err != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go
index 2bb783a85..38796d4db 100644
--- a/pkg/sentry/syscalls/linux/vfs2/path.go
+++ b/pkg/sentry/syscalls/linux/vfs2/path.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
func copyInPath(t *kernel.Task, addr hostarch.Addr) (fspath.Path, error) {
@@ -44,7 +43,7 @@ func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldA
if !path.Absolute {
if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
root.DecRef(t)
- return taskPathOperation{}, syserror.ENOENT
+ return taskPathOperation{}, linuxerr.ENOENT
}
if dirfd == linux.AT_FDCWD {
start = t.FSContext().WorkingDirectoryVFS2()
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
index 042aa4c97..204051cd0 100644
--- a/pkg/sentry/syscalls/linux/vfs2/poll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -20,15 +20,14 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/waiter"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// fileCap is the maximum allowable files for poll & select. This has no
@@ -189,7 +188,7 @@ func doPoll(t *kernel.Task, addr hostarch.Addr, nfds uint, timeout time.Duration
pfd[i].Events |= linux.POLLHUP | linux.POLLERR
}
remainingTimeout, n, err := pollBlock(t, pfd, timeout)
- err = syserror.ConvertIntr(err, syserror.EINTR)
+ err = syserr.ConvertIntr(err, linuxerr.EINTR)
// The poll entries are copied out regardless of whether
// any are set or not. This aligns with the Linux behavior.
@@ -299,7 +298,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Ad
// Do the syscall, then count the number of bits set.
if _, _, err = pollBlock(t, pfd, timeout); err != nil {
- return 0, syserror.ConvertIntr(err, syserror.EINTR)
+ return 0, syserr.ConvertIntr(err, linuxerr.EINTR)
}
// r, w, and e are currently event mask bitsets; unset bits corresponding
@@ -417,7 +416,7 @@ func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duratio
nfds: nfds,
timeout: remainingTimeout,
})
- return 0, syserror.ERESTART_RESTARTBLOCK
+ return 0, linuxerr.ERESTART_RESTARTBLOCK
}
return n, err
}
@@ -464,7 +463,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Note that this means that if err is nil but copyErr is not, copyErr is
// ignored. This is consistent with Linux.
if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
- err = syserror.ERESTARTNOHAND
+ err = linuxerr.ERESTARTNOHAND
}
return n, nil, err
}
@@ -494,7 +493,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
// See comment in Ppoll.
if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
- err = syserror.ERESTARTNOHAND
+ err = linuxerr.ERESTARTNOHAND
}
return n, nil, err
}
@@ -541,7 +540,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
// See comment in Ppoll.
if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
- err = syserror.ERESTARTNOHAND
+ err = linuxerr.ERESTARTNOHAND
}
return n, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
index fe8aa06da..4e7dc5080 100644
--- a/pkg/sentry/syscalls/linux/vfs2/read_write.go
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -63,7 +62,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
n, err := read(t, file, dst, vfs.ReadOptions{})
t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "read", file)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "read", file)
}
// Readv implements Linux syscall readv(2).
@@ -88,12 +87,12 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
n, err := read(t, file, dst, vfs.ReadOptions{})
t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "readv", file)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "readv", file)
}
func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
n, err := file.Read(t, dst, opts)
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
return n, err
}
@@ -115,14 +114,14 @@ func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opt
// "would block".
n, err = file.Read(t, dst, opts)
total += n
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
break
}
// Wait for a notification that we should retry.
if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
break
}
@@ -166,7 +165,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pread64", file)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "pread64", file)
}
// Preadv implements Linux syscall preadv(2).
@@ -197,7 +196,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "preadv", file)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "preadv", file)
}
// Preadv2 implements Linux syscall preadv2(2).
@@ -243,12 +242,12 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
n, err = pread(t, file, dst, offset, opts)
}
t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "preadv2", file)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "preadv2", file)
}
func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
n, err := file.PRead(t, dst, offset, opts)
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
return n, err
}
@@ -270,14 +269,14 @@ func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, of
// "would block".
n, err = file.PRead(t, dst, offset+total, opts)
total += n
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
break
}
// Wait for a notification that we should retry.
if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
break
}
@@ -314,7 +313,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
n, err := write(t, file, src, vfs.WriteOptions{})
t.IOUsage().AccountWriteSyscall(n)
- return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "write", file)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "write", file)
}
// Writev implements Linux syscall writev(2).
@@ -339,12 +338,12 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
n, err := write(t, file, src, vfs.WriteOptions{})
t.IOUsage().AccountWriteSyscall(n)
- return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "writev", file)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "writev", file)
}
func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
n, err := file.Write(t, src, opts)
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
return n, err
}
@@ -366,14 +365,14 @@ func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, op
// "would block".
n, err = file.Write(t, src, opts)
total += n
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
break
}
// Wait for a notification that we should retry.
if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
break
}
@@ -416,7 +415,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
t.IOUsage().AccountWriteSyscall(n)
- return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pwrite64", file)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "pwrite64", file)
}
// Pwritev implements Linux syscall pwritev(2).
@@ -447,7 +446,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
t.IOUsage().AccountWriteSyscall(n)
- return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pwritev", file)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "pwritev", file)
}
// Pwritev2 implements Linux syscall pwritev2(2).
@@ -493,12 +492,12 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
n, err = pwrite(t, file, src, offset, opts)
}
t.IOUsage().AccountWriteSyscall(n)
- return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pwritev2", file)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "pwritev2", file)
}
func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
n, err := file.PWrite(t, src, offset, opts)
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
return n, err
}
@@ -520,14 +519,14 @@ func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, o
// "would block".
n, err = file.PWrite(t, src, offset+total, opts)
total += n
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
break
}
// Wait for a notification that we should retry.
if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
break
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index b5a3b92c5..e608572b4 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
@@ -432,7 +431,7 @@ func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPa
start := root
if !path.Absolute {
if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if dirfd == linux.AT_FDCWD {
start = t.FSContext().WorkingDirectoryVFS2()
@@ -465,7 +464,7 @@ func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPa
}
func handleSetSizeError(t *kernel.Task, err error) error {
- if err == syserror.ErrExceedsFileSizeLimit {
+ if err == linuxerr.ErrExceedsFileSizeLimit {
// Convert error to EFBIG and send a SIGXFSZ per setrlimit(2).
t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t))
return linuxerr.EFBIG
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 0c2e0720b..48be5a88d 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -30,10 +31,7 @@ import (
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
// maxAddrLen is the maximum socket address length we're willing to accept.
@@ -264,7 +262,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Capture address and call syscall implementation.
@@ -274,7 +272,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0
- return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), syserror.ERESTARTSYS)
+ return 0, nil, syserr.ConvertIntr(s.Connect(t, a, blocking).ToError(), linuxerr.ERESTARTSYS)
}
// accept is the implementation of the accept syscall. It is called by accept
@@ -295,7 +293,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr,
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, syserror.ENOTSOCK
+ return 0, linuxerr.ENOTSOCK
}
// Call the syscall implementation for this socket, then copy the
@@ -305,7 +303,7 @@ func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr,
peerRequested := addrLen != 0
nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking)
if e != nil {
- return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS)
+ return 0, syserr.ConvertIntr(e.ToError(), linuxerr.ERESTARTSYS)
}
if peerRequested {
// NOTE(magi): Linux does not give you an error if it can't
@@ -354,7 +352,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Capture address and call syscall implementation.
@@ -381,7 +379,7 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
if backlog > maxListenBacklog {
@@ -419,7 +417,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Validate how, then call syscall implementation.
@@ -450,7 +448,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Read the length. Reject negative values.
@@ -531,7 +529,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
if optLen < 0 {
@@ -569,7 +567,7 @@ func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Get the socket name and copy it to the caller.
@@ -597,7 +595,7 @@ func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Get the socket peer name and copy it to the caller.
@@ -630,7 +628,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Reject flags that we don't handle yet.
@@ -687,7 +685,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
@@ -767,7 +765,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr hostarch.Addr, fl
if msg.ControlLen == 0 && msg.NameLen == 0 {
n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
if err != nil {
- return 0, syserror.ConvertIntr(err.ToError(), syserror.ERESTARTSYS)
+ return 0, syserr.ConvertIntr(err.ToError(), linuxerr.ERESTARTSYS)
}
if !cms.Unix.Empty() {
mflags |= linux.MSG_CTRUNC
@@ -789,7 +787,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr hostarch.Addr, fl
}
n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen)
if e != nil {
- return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS)
+ return 0, syserr.ConvertIntr(e.ToError(), linuxerr.ERESTARTSYS)
}
defer cms.Release(t)
@@ -852,7 +850,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, fla
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, syserror.ENOTSOCK
+ return 0, linuxerr.ENOTSOCK
}
if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
@@ -878,7 +876,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, fla
n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
cm.Release(t)
if e != nil {
- return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS)
+ return 0, syserr.ConvertIntr(e.ToError(), linuxerr.ERESTARTSYS)
}
// Copy the address to the caller.
@@ -925,7 +923,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Reject flags that we don't handle yet.
@@ -967,7 +965,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, nil, syserror.ENOTSOCK
+ return 0, nil, linuxerr.ENOTSOCK
}
// Reject flags that we don't handle yet.
@@ -1064,7 +1062,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio
// Call the syscall implementation.
n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages)
- err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendmsg", file)
+ err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), linuxerr.ERESTARTSYS, "sendmsg", file)
// Control messages should be released on error as well as for zero-length
// messages, which are discarded by the receiver.
if n == 0 || err != nil {
@@ -1091,7 +1089,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags
// Extract the socket.
s, ok := file.Impl().(socket.SocketVFS2)
if !ok {
- return 0, syserror.ENOTSOCK
+ return 0, linuxerr.ENOTSOCK
}
if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
@@ -1126,7 +1124,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags
// Call the syscall implementation.
n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)})
- return uintptr(n), slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendto", file)
+ return uintptr(n), slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), linuxerr.ERESTARTSYS, "sendto", file)
}
// SendTo implements the linux syscall sendto(2).
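Across the socket hunks above the rewrite is uniform: syserror.ConvertIntr(err, syserror.ERESTARTSYS) becomes syserr.ConvertIntr(err, linuxerr.ERESTARTSYS), using the ConvertIntr helper this diff adds to pkg/syserr further down. A minimal sketch of the resulting flow for a blocking connect; the surrounding wiring is illustrative only:

// Sketch: s, t, and a are the socket, task, and address from the handler.
err := s.Connect(t, a, true /* blocking */).ToError()
// If the wait was interrupted by a signal, ErrInterrupted is swapped for
// ERESTARTSYS, so the syscall is restarted (subject to SA_RESTART) instead
// of surfacing EINTR to userspace.
return 0, nil, syserr.ConvertIntr(err, linuxerr.ERESTARTSYS)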
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index d8009123f..0205f09e0 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -151,7 +150,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
panic("at least one end of splice must be a pipe")
}
- if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
+ if n != 0 || err != linuxerr.ErrWouldBlock || nonBlock {
break
}
if err = dw.waitForBoth(t); err != nil {
@@ -173,7 +172,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// We can only pass a single file to handleIOError, so pick outFile arbitrarily.
// This is used only for debugging purposes.
- return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "splice", outFile)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "splice", outFile)
}
// Tee implements Linux syscall tee(2).
@@ -241,7 +240,7 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
defer dw.destroy()
for {
n, err = pipe.Tee(t, outPipeFD, inPipeFD, count)
- if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
+ if n != 0 || err != linuxerr.ErrWouldBlock || nonBlock {
break
}
if err = dw.waitForBoth(t); err != nil {
@@ -251,7 +250,7 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
if n != 0 {
// If a partial write is completed, the error is dropped. Log it here.
- if err != nil && err != io.EOF && err != syserror.ErrWouldBlock {
+ if err != nil && err != io.EOF && err != linuxerr.ErrWouldBlock {
log.Debugf("tee completed a partial write with error: %v", err)
err = nil
}
@@ -259,7 +258,7 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
// We can only pass a single file to handleIOError, so pick inFile arbitrarily.
// This is used only for debugging purposes.
- return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "tee", inFile)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, linuxerr.ERESTARTSYS, "tee", inFile)
}
// Sendfile implements linux system call sendfile(2).
@@ -360,10 +359,10 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
break
}
if err == nil && t.Interrupted() {
- err = syserror.ErrInterrupted
+ err = linuxerr.ErrInterrupted
break
}
- if err == syserror.ErrWouldBlock && !nonBlock {
+ if err == linuxerr.ErrWouldBlock && !nonBlock {
err = dw.waitForBoth(t)
}
if err != nil {
@@ -389,7 +388,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
var writeN int64
writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{})
wbuf = wbuf[writeN:]
- if err == syserror.ErrWouldBlock && !nonBlock {
+ if err == linuxerr.ErrWouldBlock && !nonBlock {
err = dw.waitForOut(t)
}
if err != nil {
@@ -420,10 +419,10 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
break
}
if err == nil && t.Interrupted() {
- err = syserror.ErrInterrupted
+ err = linuxerr.ErrInterrupted
break
}
- if err == syserror.ErrWouldBlock && !nonBlock {
+ if err == linuxerr.ErrWouldBlock && !nonBlock {
err = dw.waitForBoth(t)
}
if err != nil {
@@ -441,7 +440,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
if total != 0 {
- if err != nil && err != io.EOF && err != syserror.ErrWouldBlock {
+ if err != nil && err != io.EOF && err != linuxerr.ErrWouldBlock {
// If a partial write is completed, the error is dropped. Log it here.
log.Debugf("sendfile completed a partial write with error: %v", err)
err = nil
@@ -450,7 +449,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// We can only pass a single file to handleIOError, so pick inFile arbitrarily.
// This is used only for debugging purposes.
- return uintptr(total), nil, slinux.HandleIOErrorVFS2(t, total != 0, err, syserror.ERESTARTSYS, "sendfile", inFile)
+ return uintptr(total), nil, slinux.HandleIOErrorVFS2(t, total != 0, err, linuxerr.ERESTARTSYS, "sendfile", inFile)
}
// dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not
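Splice, Tee, and Sendfile all share the retry loop being retargeted in this file: attempt the transfer, and if nothing moved and the error is the internal would-block sentinel, wait for readiness and try again. A condensed sketch of that loop, assuming the dualWaiter described above:

for {
	n, err = pipe.Tee(t, outPipeFD, inPipeFD, count) // or a splice/sendfile step
	if n != 0 || err != linuxerr.ErrWouldBlock || nonBlock {
		break // made progress, hit a real error, or the FD is nonblocking
	}
	if err = dw.waitForBoth(t); err != nil {
		break // e.g. interrupted while waiting
	}
}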
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
index ba1d30823..adaf8db3f 100644
--- a/pkg/sentry/syscalls/linux/vfs2/stat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Stat implements Linux syscall stat(2).
@@ -70,7 +69,7 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr hostarch.Addr, flag
start := root
if !path.Absolute {
if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if dirfd == linux.AT_FDCWD {
start = t.FSContext().WorkingDirectoryVFS2()
@@ -182,7 +181,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
start := root
if !path.Absolute {
if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
- return 0, nil, syserror.ENOENT
+ return 0, nil, linuxerr.ENOENT
}
if dirfd == linux.AT_FDCWD {
start = t.FSContext().WorkingDirectoryVFS2()
diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go
index d0ffc7c32..cfc693422 100644
--- a/pkg/sentry/syscalls/linux/vfs2/sync.go
+++ b/pkg/sentry/syscalls/linux/vfs2/sync.go
@@ -19,7 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/syserr"
)
// Sync implements Linux syscall sync(2).
@@ -108,12 +108,12 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
if flags&linux.SYNC_FILE_RANGE_WAIT_BEFORE != 0 &&
flags&linux.SYNC_FILE_RANGE_WAIT_AFTER == 0 {
t.Kernel().EmitUnimplementedEvent(t)
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
}
if flags&linux.SYNC_FILE_RANGE_WAIT_AFTER != 0 {
if err := file.Sync(t); err != nil {
- return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS)
+ return 0, nil, syserr.ConvertIntr(err, linuxerr.ERESTARTSYS)
}
}
return 0, nil, nil
diff --git a/pkg/sentry/syscalls/syscalls.go b/pkg/sentry/syscalls/syscalls.go
index 511fb8b28..cfcc21271 100644
--- a/pkg/sentry/syscalls/syscalls.go
+++ b/pkg/sentry/syscalls/syscalls.go
@@ -31,7 +31,6 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/syserror"
)
// Supported returns a syscall that is fully supported.
@@ -103,10 +102,10 @@ func CapError(name string, c linux.Capability, note string, urls []string) kerne
return 0, nil, linuxerr.EPERM
}
t.Kernel().EmitUnimplementedEvent(t)
- return 0, nil, syserror.ENOSYS
+ return 0, nil, linuxerr.ENOSYS
},
SupportLevel: kernel.SupportUnimplemented,
- Note: fmt.Sprintf("%sReturns %q if the process does not have %s; %q otherwise.", note, linuxerr.EPERM, c.String(), syserror.ENOSYS),
+ Note: fmt.Sprintf("%sReturns %q if the process does not have %s; %q otherwise.", note, linuxerr.EPERM, c.String(), linuxerr.ENOSYS),
URLs: urls,
}
}
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index 36d999c47..c21971322 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -39,7 +39,6 @@ go_library(
"//pkg/log",
"//pkg/metric",
"//pkg/sync",
- "//pkg/syserror",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/time/sampler_arm64.go b/pkg/sentry/time/sampler_arm64.go
index 3560e66ae..9b8c9a480 100644
--- a/pkg/sentry/time/sampler_arm64.go
+++ b/pkg/sentry/time/sampler_arm64.go
@@ -30,9 +30,9 @@ func getDefaultArchOverheadCycles() TSCValue {
// frqRatio. defaultOverheadCycles of ARM equals that on
// x86 divided by frqRatio
cntfrq := getCNTFRQ()
- frqRatio := 1000000000 / cntfrq
+ frqRatio := 1000000000 / float64(cntfrq)
overheadCycles := (1 * 1000) / frqRatio
- return overheadCycles
+ return TSCValue(overheadCycles)
}
// defaultOverheadTSC is the default estimated syscall overhead in TSC cycles.
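The float64 conversion above is not cosmetic: getCNTFRQ returns an integer TSCValue, so with a generic timer clocked at 1 GHz or more the old all-integer expression truncated frqRatio to 1 or 0, and a zero ratio would panic on the next division. Illustrative values, assuming TSCValue's integer underlying type:

cntfrq := TSCValue(2_000_000_000)        // hypothetical >1 GHz counter
_ = 1000000000 / cntfrq                  // old integer math: 0, so 1000/0 panics
frqRatio := 1000000000 / float64(cntfrq) // 0.5
overheadCycles := (1 * 1000) / frqRatio  // 2000
_ = TSCValue(overheadCycles)             // truncate once, at the end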
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index a2032162d..914574543 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -116,7 +116,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/uniqueid",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
@@ -137,7 +136,6 @@ go_test(
"//pkg/errors/linuxerr",
"//pkg/sentry/contexttest",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md
index 5aad31b78..82ee2c521 100644
--- a/pkg/sentry/vfs/README.md
+++ b/pkg/sentry/vfs/README.md
@@ -1,9 +1,5 @@
# The gVisor Virtual Filesystem
-THIS PACKAGE IS CURRENTLY EXPERIMENTAL AND NOT READY OR ENABLED FOR PRODUCTION
-USE. For the filesystem implementation currently used by gVisor, see the `fs`
-package.
-
## Implementation Notes
### Reference Counting
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 23ccc6b66..fefd0fc9c 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -263,7 +262,7 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event
num: num,
}]
if !ok {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
// Update epi for the next call to ep.ReadEvents().
@@ -299,7 +298,7 @@ func (ep *EpollInstance) DeleteInterest(file *FileDescription, num int32) error
num: num,
}]
if !ok {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
// Unregister from the file so that epi will no longer be readied.
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index a875fdeca..452f5f1f9 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -17,6 +17,7 @@ package vfs
import (
"bytes"
"io"
+ "math"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -25,7 +26,6 @@ import (
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -56,7 +56,7 @@ func (FileDescriptionDefaultImpl) OnClose(ctx context.Context) error {
// StatFS implements FileDescriptionImpl.StatFS analogously to
// super_operations::statfs == NULL in Linux.
func (FileDescriptionDefaultImpl) StatFS(ctx context.Context) (linux.Statfs, error) {
- return linux.Statfs{}, syserror.ENOSYS
+ return linux.Statfs{}, linuxerr.ENOSYS
}
// Allocate implements FileDescriptionImpl.Allocate analogously to
@@ -175,27 +175,27 @@ type DirectoryFileDescriptionDefaultImpl struct{}
// Allocate implements DirectoryFileDescriptionDefaultImpl.Allocate.
func (DirectoryFileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error {
- return syserror.EISDIR
+ return linuxerr.EISDIR
}
// PRead implements FileDescriptionImpl.PRead.
func (DirectoryFileDescriptionDefaultImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
// Read implements FileDescriptionImpl.Read.
func (DirectoryFileDescriptionDefaultImpl) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
// PWrite implements FileDescriptionImpl.PWrite.
func (DirectoryFileDescriptionDefaultImpl) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
// Write implements FileDescriptionImpl.Write.
func (DirectoryFileDescriptionDefaultImpl) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
- return 0, syserror.EISDIR
+ return 0, linuxerr.EISDIR
}
// DentryMetadataFileDescriptionImpl may be embedded by implementations of
@@ -368,7 +368,7 @@ func (fd *DynamicBytesFileDescriptionImpl) pwriteLocked(ctx context.Context, src
writable, ok := fd.data.(WritableDynamicBytesSource)
if !ok {
- return 0, syserror.EIO
+ return 0, linuxerr.EIO
}
n, err := writable.Write(ctx, src, offset)
if err != nil {
@@ -400,6 +400,9 @@ func (fd *DynamicBytesFileDescriptionImpl) Write(ctx context.Context, src userme
// GenericConfigureMMap may be used by most implementations of
// FileDescriptionImpl.ConfigureMMap.
func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.MMapOpts) error {
+ if opts.Offset+opts.Length > math.MaxInt64 {
+ return linuxerr.EOVERFLOW
+ }
opts.Mappable = m
opts.MappingIdentity = fd
fd.IncRef()
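MMapOpts.Offset and MMapOpts.Length are both uint64, so their sum can exceed the int64 range used for file offsets; the new guard rejects such mappings with EOVERFLOW before they reach a memmap.Mappable. A sketch of a request the check now refuses (values illustrative):

opts := memmap.MMapOpts{Offset: math.MaxInt64, Length: 4096}
if opts.Offset+opts.Length > math.MaxInt64 {
	// The end of the mapping is not representable as an int64 file offset.
	return linuxerr.EOVERFLOW
}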
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
index 3423dede1..e34a8c11b 100644
--- a/pkg/sentry/vfs/file_description_impl_util_test.go
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/contexttest"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -157,10 +156,10 @@ func TestGenCountFD(t *testing.T) {
// Write and PWrite fails.
if _, err := fd.Write(ctx, ioseq, WriteOptions{}); !linuxerr.Equals(linuxerr.EIO, err) {
- t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO)
+ t.Errorf("Write: got err %v, wanted %v", err, linuxerr.EIO)
}
if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); !linuxerr.Equals(linuxerr.EIO, err) {
- t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO)
+ t.Errorf("Write: got err %v, wanted %v", err, linuxerr.EIO)
}
}
diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go
index 088beb8e2..17d94b341 100644
--- a/pkg/sentry/vfs/inotify.go
+++ b/pkg/sentry/vfs/inotify.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -209,7 +208,7 @@ func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOpt
if i.events.Empty() {
// Nothing to read yet, tell caller to block.
- return 0, syserror.ErrWouldBlock
+ return 0, linuxerr.ErrWouldBlock
}
var writeLen int64
diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go
index cbe4d8c2d..1853cdca0 100644
--- a/pkg/sentry/vfs/lock.go
+++ b/pkg/sentry/vfs/lock.go
@@ -17,8 +17,8 @@ package vfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
- "gvisor.dev/gvisor/pkg/syserror"
)
// FileLocks supports POSIX and BSD style locks, which correspond to fcntl(2)
@@ -47,9 +47,9 @@ func (fl *FileLocks) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerID i
// Return an appropriate error for the unsuccessful lock attempt, depending on
// whether this is a blocking or non-blocking operation.
if block == nil {
- return syserror.ErrWouldBlock
+ return linuxerr.ErrWouldBlock
}
- return syserror.ERESTARTSYS
+ return linuxerr.ERESTARTSYS
}
// UnlockBSD releases a BSD-style lock on the entire file.
@@ -69,9 +69,9 @@ func (fl *FileLocks) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPI
// Return an appropriate error for the unsuccessful lock attempt, depending on
// whether this is a blocking or non-blocking operation.
if block == nil {
- return syserror.ErrWouldBlock
+ return linuxerr.ErrWouldBlock
}
- return syserror.ERESTARTSYS
+ return linuxerr.ERESTARTSYS
}
// UnlockPOSIX releases a POSIX-style lock on a file region.
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 4d6b59a26..05a416775 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -27,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/syserror"
)
// A Mount is a replacement of a Dentry (Mount.key.point) from one Filesystem
@@ -225,7 +224,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr
vdDentry.mu.Unlock()
vfs.mountMu.Unlock()
vd.DecRef(ctx)
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
// vd might have been mounted over between vfs.GetDentryAt() and
// vfs.mountMu.Lock().
diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go
index e4da15009..7cc68a157 100644
--- a/pkg/sentry/vfs/pathname.go
+++ b/pkg/sentry/vfs/pathname.go
@@ -16,9 +16,9 @@ package vfs
import (
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
var fspathBuilderPool = sync.Pool{
@@ -137,7 +137,7 @@ loop:
// Linux's sys_getcwd().
func (vfs *VirtualFilesystem) PathnameForGetcwd(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) {
if vd.dentry.IsDead() {
- return "", syserror.ENOENT
+ return "", linuxerr.ENOENT
}
b := getFSPathBuilder()
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index 4744514bd..953d31876 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/limits"
- "gvisor.dev/gvisor/pkg/syserror"
)
// AccessTypes is a bitmask of Unix file permissions.
@@ -195,7 +194,7 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, opts *SetStatOpt
return err
}
if limit < int64(stat.Size) {
- return syserror.ErrExceedsFileSizeLimit
+ return linuxerr.ErrExceedsFileSizeLimit
}
}
if stat.Mask&linux.STATX_MODE != 0 {
@@ -282,7 +281,7 @@ func CheckLimit(ctx context.Context, offset, size int64) (int64, error) {
return size, nil
}
if offset >= int64(fileSizeLimit) {
- return 0, syserror.ErrExceedsFileSizeLimit
+ return 0, linuxerr.ErrExceedsFileSizeLimit
}
remaining := int64(fileSizeLimit) - offset
if remaining < size {
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index 6f58f33ce..40aff2927 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// ResolvingPath represents the state of an in-progress path resolution, shared
@@ -224,6 +223,12 @@ func (rp *ResolvingPath) Final() bool {
return rp.curPart == 0 && !rp.pit.NextOk()
}
+// Pit returns a copy of rp's current path iterator. Modifying the iterator
+// does not change rp.
+func (rp *ResolvingPath) Pit() fspath.Iterator {
+ return rp.pit
+}
+
// Component returns the current path component in the stream represented by
// rp.
//
@@ -331,7 +336,7 @@ func (rp *ResolvingPath) HandleSymlink(target string) error {
return linuxerr.ELOOP
}
if len(target) == 0 {
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
rp.symlinks++
targetPath := fspath.Parse(target)
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index eb3c60610..1b2a668c0 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -48,7 +48,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
)
// A VirtualFilesystem (VFS for short) combines Filesystems in trees of Mounts.
@@ -281,7 +280,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
if newpop.Path.Absolute {
return linuxerr.EEXIST
}
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if newpop.FollowFinalSymlink {
oldVD.DecRef(ctx)
@@ -318,7 +317,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
if pop.Path.Absolute {
return linuxerr.EEXIST
}
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if pop.FollowFinalSymlink {
ctx.Warningf("VirtualFilesystem.MkdirAt: file creation paths can't follow final symlink")
@@ -348,7 +347,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
}
// MknodAt creates a file of the given mode at the given path. It returns an
-// error from the syserror package.
+// error from the linuxerr package.
func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error {
if !pop.Path.Begin.Ok() {
// pop.Path should not be empty in operations that create/delete files.
@@ -356,7 +355,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
if pop.Path.Absolute {
return linuxerr.EEXIST
}
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if pop.FollowFinalSymlink {
ctx.Warningf("VirtualFilesystem.MknodAt: file creation paths can't follow final symlink")
@@ -494,7 +493,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
if oldpop.Path.Absolute {
return linuxerr.EBUSY
}
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if oldpop.FollowFinalSymlink {
ctx.Warningf("VirtualFilesystem.RenameAt: source path can't follow final symlink")
@@ -515,7 +514,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
if newpop.Path.Absolute {
return linuxerr.EBUSY
}
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if newpop.FollowFinalSymlink {
oldParentVD.DecRef(ctx)
@@ -556,7 +555,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia
if pop.Path.Absolute {
return linuxerr.EBUSY
}
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if pop.FollowFinalSymlink {
ctx.Warningf("VirtualFilesystem.RmdirAt: file deletion paths can't follow final symlink")
@@ -639,7 +638,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
if pop.Path.Absolute {
return linuxerr.EEXIST
}
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if pop.FollowFinalSymlink {
ctx.Warningf("VirtualFilesystem.SymlinkAt: file creation paths can't follow final symlink")
@@ -673,7 +672,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
if pop.Path.Absolute {
return linuxerr.EBUSY
}
- return syserror.ENOENT
+ return linuxerr.ENOENT
}
if pop.FollowFinalSymlink {
ctx.Warningf("VirtualFilesystem.UnlinkAt: file deletion paths can't follow final symlink")
diff --git a/pkg/shim/runtimeoptions/runtimeoptions.go b/pkg/shim/runtimeoptions/runtimeoptions.go
index 072dd87f0..e76d73ea7 100644
--- a/pkg/shim/runtimeoptions/runtimeoptions.go
+++ b/pkg/shim/runtimeoptions/runtimeoptions.go
@@ -15,3 +15,10 @@
// Package runtimeoptions contains the runtimeoptions proto.
package runtimeoptions
+
+import proto "github.com/gogo/protobuf/proto"
+
+func init() {
+ // TODO(gvisor.dev/issue/6449): Upgrade runtimeoptions.proto after upgrading to containerd 1.5
+ proto.RegisterType((*Options)(nil), "runtimeoptions.v1.Options")
+}
diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sighandling/BUILD
index 1790d57c9..72f10f982 100644
--- a/pkg/sentry/sighandling/BUILD
+++ b/pkg/sighandling/BUILD
@@ -8,7 +8,7 @@ go_library(
"sighandling.go",
"sighandling_unsafe.go",
],
- visibility = ["//pkg/sentry:internal"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sighandling/sighandling.go
index bdaf8af29..bdaf8af29 100644
--- a/pkg/sentry/sighandling/sighandling.go
+++ b/pkg/sighandling/sighandling.go
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sighandling/sighandling_unsafe.go
index 3fe5c6770..7deeda042 100644
--- a/pkg/sentry/sighandling/sighandling_unsafe.go
+++ b/pkg/sighandling/sighandling_unsafe.go
@@ -15,6 +15,7 @@
package sighandling
import (
+ "fmt"
"unsafe"
"golang.org/x/sys/unix"
@@ -37,3 +38,36 @@ func IgnoreChildStop() error {
return nil
}
+
+// ReplaceSignalHandler replaces the existing signal handler for the provided
+// signal with the function pointer at `handler`. This bypasses the Go runtime
+// signal handlers, and should only be used for low-level signal handlers where
+// use of signal.Notify is not appropriate.
+//
+// It stores the value of the previously set handler in previous.
+func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error {
+ var sa linux.SigAction
+ const maskLen = 8
+
+ // Get the existing signal handler information, and save the current
+ // handler. Once we replace it, we will use this pointer to fall back to
+ // it when we receive other signals.
+ if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 {
+ return e
+ }
+
+ // Fail if there isn't a previous handler.
+ if sa.Handler == 0 {
+ return fmt.Errorf("previous handler for signal %x isn't set", sig)
+ }
+
+ *previous = uintptr(sa.Handler)
+
+ // Install our own handler.
+ sa.Handler = uint64(handler)
+ if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 {
+ return e
+ }
+
+ return nil
+}
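ReplaceSignalHandler installs a raw program counter, so handler must point at something signal-safe outside the Go runtime, typically an assembly trampoline; it cannot be an ordinary Go function. A hedged usage sketch, where sigsysHandlerAddr is a hypothetical trampoline address:

var previousSIGSYS uintptr
if err := sighandling.ReplaceSignalHandler(unix.SIGSYS, sigsysHandlerAddr, &previousSIGSYS); err != nil {
	panic(fmt.Sprintf("replacing SIGSYS handler: %v", err))
}
// The trampoline can chain to previousSIGSYS for signals it does not handle.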
diff --git a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go
index 82b6df18c..7b9c2a4db 100644
--- a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go
+++ b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go
@@ -37,6 +37,8 @@ func (p *AtomicPtr) loadPtr(v *Value) {
// Load returns the value set by the most recent Store. It returns nil if there
// has been no previous call to Store.
+//
+//go:nosplit
func (p *AtomicPtr) Load() *Value {
return (*Value)(atomic.LoadPointer(&p.ptr))
}
diff --git a/pkg/sync/atomicptrmap/BUILD b/pkg/sync/atomicptrmap/BUILD
index b0e218c79..1c0085d69 100644
--- a/pkg/sync/atomicptrmap/BUILD
+++ b/pkg/sync/atomicptrmap/BUILD
@@ -19,6 +19,7 @@ go_template(
"Key",
"Value",
],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/gohacks",
"//pkg/sync",
diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD
index ceee494fc..d8c4c9613 100644
--- a/pkg/syserr/BUILD
+++ b/pkg/syserr/BUILD
@@ -14,7 +14,7 @@ go_library(
"//pkg/abi/linux/errno",
"//pkg/errors",
"//pkg/errors/linuxerr",
- "//pkg/syserror",
+ "//pkg/safecopy",
"//pkg/tcpip",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/syserr/syserr.go b/pkg/syserr/syserr.go
index 558240008..b679f3046 100644
--- a/pkg/syserr/syserr.go
+++ b/pkg/syserr/syserr.go
@@ -24,7 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux/errno"
"gvisor.dev/gvisor/pkg/errors"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
- "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/safecopy"
)
// Error represents an internal error.
@@ -52,12 +52,12 @@ func New(message string, linuxTranslation errno.Errno) *Error {
}
e := error(unix.Errno(err.errno))
- // syserror.ErrWouldBlock gets translated to linuxerr.EWOULDBLOCK and
+ // linuxerr.ErrWouldBlock gets translated to linuxerr.EWOULDBLOCK and
// enables proper blocking semantics. This should temporarily address the
// class of blocking bugs that keep popping up with the current state of
// the error space.
if err.errno == linuxerr.EWOULDBLOCK.Errno() {
- e = syserror.ErrWouldBlock
+ e = linuxerr.ErrWouldBlock
}
linuxBackwardsTranslations[err.errno] = linuxBackwardsTranslation{err: e, ok: true}
@@ -279,16 +279,25 @@ func FromError(err error) *Error {
if err == nil {
return nil
}
- if errno, ok := err.(unix.Errno); ok {
- return FromHost(errno)
- }
- if linuxErr, ok := err.(*errors.Error); ok {
- return FromHost(unix.Errno(linuxErr.Errno()))
+ switch e := err.(type) {
+ case unix.Errno:
+ return FromHost(e)
+ case *errors.Error:
+ return FromHost(unix.Errno(e.Errno()))
+ case safecopy.SegvError, safecopy.BusError, safecopy.AlignmentError:
+ return FromHost(unix.EFAULT)
}
- if errno, ok := syserror.TranslateError(err); ok {
- return FromHost(errno)
+ msg := fmt.Sprintf("err: %s type: %T", err.Error(), err)
+ panic(msg)
+}
+
+// ConvertIntr converts the provided error code (err) to another one (intr) if
+// the first error corresponds to an interrupted operation.
+func ConvertIntr(err, intr error) error {
+ if err == linuxerr.ErrInterrupted {
+ return intr
}
- panic("unknown error: " + err.Error())
+ return err
}
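With syserror gone, FromError no longer falls back to a registered translation map: the switch handles host errnos, linuxerr values, and safecopy faults, and any other type is treated as a programming error. The cases spelled out (all of the types appear in this diff):

syserr.FromError(unix.EAGAIN)                 // unix.Errno -> FromHost
syserr.FromError(linuxerr.ENOENT)             // *errors.Error -> FromHost(errno)
syserr.FromError(safecopy.SegvError{Addr: 0}) // fault in a safe copy -> EFAULT
// Anything else now panics with "err: ... type: ..." instead of being mapped.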
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
deleted file mode 100644
index b24edb364..000000000
--- a/pkg/syserror/syserror.go
+++ /dev/null
@@ -1,180 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package syserror contains syscall error codes exported as error interface
-// instead of Errno. This allows for fast comparison and returns when the
-// comparand or return value is of type error because there is no need to
-// convert from Errno to an interface, i.e., runtime.convT2I isn't called.
-package syserror
-
-import (
- "errors"
-
- "golang.org/x/sys/unix"
-)
-
-// The following variables have the same meaning as their syscall equivalent.
-var (
- EIDRM = error(unix.EIDRM)
- EINTR = error(unix.EINTR)
- EIO = error(unix.EIO)
- EISDIR = error(unix.EISDIR)
- ENOENT = error(unix.ENOENT)
- ENOEXEC = error(unix.ENOEXEC)
- ENOMEM = error(unix.ENOMEM)
- ENOTSOCK = error(unix.ENOTSOCK)
- ENOSPC = error(unix.ENOSPC)
- ENOSYS = error(unix.ENOSYS)
-)
-
-var (
- // ErrWouldBlock is an internal error used to indicate that an operation
- // cannot be satisfied immediately, and should be retried at a later
- // time, possibly when the caller has received a notification that the
- // operation may be able to complete. It is used by implementations of
- // the kio.File interface.
- ErrWouldBlock = errors.New("request would block")
-
- // ErrInterrupted is returned if a request is interrupted before it can
- // complete.
- ErrInterrupted = errors.New("request was interrupted")
-
- // ErrExceedsFileSizeLimit is returned if a request would exceed the
- // file's size limit.
- ErrExceedsFileSizeLimit = errors.New("exceeds file size limit")
-)
-
-// errorMap is the map used to convert generic errors into errnos.
-var errorMap = map[error]unix.Errno{}
-
-// errorUnwrappers is an array of unwrap functions to extract typed errors.
-var errorUnwrappers = []func(error) (unix.Errno, bool){}
-
-// AddErrorTranslation allows modules to populate the error map by adding their
-// own translations during initialization. Returns if the error translation is
-// accepted or not. A pre-existing translation will not be overwritten by the
-// new translation.
-func AddErrorTranslation(from error, to unix.Errno) bool {
- if _, ok := errorMap[from]; ok {
- return false
- }
-
- errorMap[from] = to
- return true
-}
-
-// AddErrorUnwrapper registers an unwrap method that can extract a concrete error
-// from a typed, but not initialized, error.
-func AddErrorUnwrapper(unwrap func(e error) (unix.Errno, bool)) {
- errorUnwrappers = append(errorUnwrappers, unwrap)
-}
-
-// TranslateError translates errors to errnos, it will return false if
-// the error was not registered.
-func TranslateError(from error) (unix.Errno, bool) {
- if err, ok := errorMap[from]; ok {
- return err, true
- }
- // Try to unwrap the error if we couldn't match an error
- // exactly. This might mean that a package has its own
- // error type.
- for _, unwrap := range errorUnwrappers {
- if err, ok := unwrap(from); ok {
- return err, true
- }
- }
- return 0, false
-}
-
-// ConvertIntr converts the provided error code (err) to another one (intr) if
-// the first error corresponds to an interrupted operation.
-func ConvertIntr(err, intr error) error {
- if err == ErrInterrupted {
- return intr
- }
- return err
-}
-
-// SyscallRestartErrno represents a ERESTART* errno defined in the Linux's kernel
-// include/linux/errno.h. These errnos are never returned to userspace
-// directly, but are used to communicate the expected behavior of an
-// interrupted syscall from the syscall to signal handling.
-type SyscallRestartErrno int
-
-// These numeric values are significant because ptrace syscall exit tracing can
-// observe them.
-//
-// For all of the following errnos, if the syscall is not interrupted by a
-// signal delivered to a user handler, the syscall is restarted.
-const (
- // ERESTARTSYS is returned by an interrupted syscall to indicate that it
- // should be converted to EINTR if interrupted by a signal delivered to a
- // user handler without SA_RESTART set, and restarted otherwise.
- ERESTARTSYS = SyscallRestartErrno(512)
-
- // ERESTARTNOINTR is returned by an interrupted syscall to indicate that it
- // should always be restarted.
- ERESTARTNOINTR = SyscallRestartErrno(513)
-
- // ERESTARTNOHAND is returned by an interrupted syscall to indicate that it
- // should be converted to EINTR if interrupted by a signal delivered to a
- // user handler, and restarted otherwise.
- ERESTARTNOHAND = SyscallRestartErrno(514)
-
- // ERESTART_RESTARTBLOCK is returned by an interrupted syscall to indicate
- // that it should be restarted using a custom function. The interrupted
- // syscall must register a custom restart function by calling
- // Task.SetRestartSyscallFn.
- ERESTART_RESTARTBLOCK = SyscallRestartErrno(516)
-)
-
-// Error implements error.Error.
-func (e SyscallRestartErrno) Error() string {
- // Descriptions are borrowed from strace.
- switch e {
- case ERESTARTSYS:
- return "to be restarted if SA_RESTART is set"
- case ERESTARTNOINTR:
- return "to be restarted"
- case ERESTARTNOHAND:
- return "to be restarted if no handler"
- case ERESTART_RESTARTBLOCK:
- return "interrupted by signal"
- default:
- return "(unknown interrupt error)"
- }
-}
-
-// SyscallRestartErrnoFromReturn returns the SyscallRestartErrno represented by
-// rv, the value in a syscall return register.
-func SyscallRestartErrnoFromReturn(rv uintptr) (SyscallRestartErrno, bool) {
- switch int(rv) {
- case -int(ERESTARTSYS):
- return ERESTARTSYS, true
- case -int(ERESTARTNOINTR):
- return ERESTARTNOINTR, true
- case -int(ERESTARTNOHAND):
- return ERESTARTNOHAND, true
- case -int(ERESTART_RESTARTBLOCK):
- return ERESTART_RESTARTBLOCK, true
- default:
- return 0, false
- }
-}
-
-func init() {
- AddErrorTranslation(ErrWouldBlock, unix.EWOULDBLOCK)
- AddErrorTranslation(ErrInterrupted, unix.EINTR)
- AddErrorTranslation(ErrExceedsFileSizeLimit, unix.EFBIG)
-}
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index f00cfd0f5..b98de54c5 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -25,6 +25,7 @@ go_library(
"stdclock.go",
"stdclock_state.go",
"tcpip.go",
+ "tcpip_state.go",
"timer.go",
],
visibility = ["//visibility:public"],
@@ -69,7 +70,6 @@ deps_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/fdbased",
"//pkg/tcpip/link/loopback",
- "//pkg/tcpip/link/packetsocket",
"//pkg/tcpip/link/qdisc/fifo",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/arp",
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 48b24692b..c8460e63c 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -137,7 +137,13 @@ func TestCloseReader(t *testing.T) {
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
@@ -190,7 +196,13 @@ func TestCloseReaderWithForwarder(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
done := make(chan struct{})
@@ -244,7 +256,13 @@ func TestCloseRead(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
@@ -288,7 +306,13 @@ func TestCloseWrite(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
@@ -349,10 +373,22 @@ func TestUDPForwarder(t *testing.T) {
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
+ }
ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
+ }
done := make(chan struct{})
fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
@@ -410,7 +446,13 @@ func TestDeadlineChange(t *testing.T) {
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
@@ -465,10 +507,22 @@ func TestPacketConnTransfer(t *testing.T) {
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
+ }
ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
+ }
c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
if err != nil {
@@ -521,7 +575,13 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
if err != nil {
@@ -565,24 +625,30 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, nil, nil, fmt.Errorf("AddProtocolAddress(%d, %+v, {}): %w", NICID, protocolAddr, err)
+ }
l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
- return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
+ return nil, nil, nil, fmt.Errorf("NewListener: %w", err)
}
c1, err = DialTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
l.Close()
- return nil, nil, nil, fmt.Errorf("DialTCP: %v", err)
+ return nil, nil, nil, fmt.Errorf("DialTCP: %w", err)
}
c2, err = l.Accept()
if err != nil {
l.Close()
c1.Close()
- return nil, nil, nil, fmt.Errorf("l.Accept: %v", err)
+ return nil, nil, nil, fmt.Errorf("l.Accept: %w", err)
}
stop = func() {
@@ -594,7 +660,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
if err := l.Close(); err != nil {
stop()
- return nil, nil, nil, fmt.Errorf("l.Close(): %v", err)
+ return nil, nil, nil, fmt.Errorf("l.Close(): %w", err)
}
return c1, c2, stop, nil
@@ -681,7 +747,13 @@ func TestDialContextTCPCanceled(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
@@ -703,7 +775,13 @@ func TestDialContextTCPTimeout(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
time.Sleep(time.Second)
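Every call site above trades s.AddAddress for s.AddProtocolAddress, which takes the full ProtocolAddress plus AddressProperties and returns an error the tests now check. The repetition could be folded into a helper along these lines (hypothetical, not part of the diff):

func addV4Address(t *testing.T, s *stack.Stack, addr tcpip.Address) {
	t.Helper()
	protocolAddr := tcpip.ProtocolAddress{
		Protocol:          ipv4.ProtocolNumber,
		AddressWithPrefix: addr.WithPrefix(),
	}
	if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
	}
}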
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index e0dfe5813..24c2c3e6b 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -324,6 +324,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
}
}
+// ReceiveIPv6PacketInfo creates a checker that checks the IPv6PacketInfo field
+// in ControlMessages.
+func ReceiveIPv6PacketInfo(want tcpip.IPv6PacketInfo) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasIPv6PacketInfo {
+ t.Errorf("got cm.HasIPv6PacketInfo = %t, want = true", cm.HasIPv6PacketInfo)
+ } else if diff := cmp.Diff(want, cm.IPv6PacketInfo); diff != "" {
+ t.Errorf("IPv6PacketInfo mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
// field in ControlMessages.
func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
@@ -729,7 +742,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
return
}
l := int(opts[i+1])
- if i < 2 || i+l > limit {
+ if l < 2 || i+l > limit {
return
}
i += l
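The old condition tested the cursor i rather than the option length l, so a malformed length byte of 0 or 1 slipped through and i += l either stopped advancing or re-read the option header. In TCP's kind/length/value encoding every multi-byte option has length >= 2, which is what the fixed check enforces:

// opts[i]       = kind
// opts[i+1]     = total length, including these two bytes (so l >= 2)
// opts[i+2:i+l] = option data
l := int(opts[i+1])
if l < 2 || i+l > limit { // l == 0 never advances; l == 1 re-reads the header
	return
}
i += l // step to the next option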
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index 95ade0e5c..1f18213e5 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -49,9 +49,9 @@ const (
// EthernetAddressSize is the size, in bytes, of an ethernet address.
EthernetAddressSize = 6
- // unspecifiedEthernetAddress is the unspecified ethernet address
+ // UnspecifiedEthernetAddress is the unspecified ethernet address
// (all bits set to 0).
- unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
+ UnspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
// EthernetBroadcastAddress is an ethernet address that addresses every node
// on a local link.
@@ -134,7 +134,7 @@ func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
return false
}
- if addr == unspecifiedEthernetAddress {
+ if addr == UnspecifiedEthernetAddress {
return false
}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index bf9ccbf1a..adc04e855 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -44,7 +44,7 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) {
},
{
"Unspecified",
- unspecifiedEthernetAddress,
+ UnspecifiedEthernetAddress,
false,
},
{
@@ -91,7 +91,7 @@ func TestIsMulticastEthernetAddress(t *testing.T) {
},
{
"Unspecified",
- unspecifiedEthernetAddress,
+ UnspecifiedEthernetAddress,
false,
},
{
diff --git a/pkg/tcpip/internal/tcp/BUILD b/pkg/tcpip/internal/tcp/BUILD
new file mode 100644
index 000000000..9ae258a0b
--- /dev/null
+++ b/pkg/tcpip/internal/tcp/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tcp",
+ srcs = ["tcp.go"],
+ visibility = ["//pkg/tcpip:__subpackages__"],
+ deps = [
+ "//pkg/tcpip",
+ ],
+)
diff --git a/pkg/tcpip/internal/tcp/tcp.go b/pkg/tcpip/internal/tcp/tcp.go
new file mode 100644
index 000000000..0616d368c
--- /dev/null
+++ b/pkg/tcpip/internal/tcp/tcp.go
@@ -0,0 +1,48 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcp contains internal type definitions that are not expected to be
+// used by anyone outside pkg/tcpip.
+package tcp
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// TSOffset is an offset applied to the value of the TSVal field in the TCP
+// Timestamp option.
+//
+// +stateify savable
+type TSOffset struct {
+ milliseconds uint32
+}
+
+// NewTSOffset creates a new TSOffset from milliseconds.
+func NewTSOffset(milliseconds uint32) TSOffset {
+ return TSOffset{
+ milliseconds: milliseconds,
+ }
+}
+
+// TSVal applies the offset to now and returns the timestamp in milliseconds.
+func (offset TSOffset) TSVal(now tcpip.MonotonicTime) uint32 {
+ return uint32(now.Sub(tcpip.MonotonicTime{}).Milliseconds()) + offset.milliseconds
+}
+
+// Elapsed calculates the elapsed time given now and the echoed back timestamp.
+func (offset TSOffset) Elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration {
+ return time.Duration(offset.TSVal(now)-tsEcr) * time.Millisecond
+}
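A hedged usage sketch for the new helpers; the package is internal to pkg/tcpip, so this only compiles from within that tree, and echoedTSEcr stands in for a TSVal echoed by the peer:

clock := tcpip.NewStdClock()
offset := tcp.NewTSOffset(12345) // e.g. a randomized per-connection offset
now := clock.NowMonotonic()
tsVal := offset.TSVal(now)              // value carried in the timestamp option
rtt := offset.Elapsed(now, echoedTSEcr) // time elapsed since the echoed TSVal
_, _ = tsVal, rtt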
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index f26c857eb..658557d62 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -28,7 +28,9 @@ import (
// PacketInfo holds all the information about an outbound packet.
type PacketInfo struct {
- Pkt *stack.PacketBuffer
+ Pkt *stack.PacketBuffer
+
+ // TODO(https://gvisor.dev/issue/6537): Remove these fields.
Proto tcpip.NetworkProtocolNumber
Route stack.RouteInfo
}
@@ -244,7 +246,10 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
Route: r,
}
- e.q.Write(p)
+ // Write returns false if the queue is full. A full queue is not an error
+ // from the perspective of a LinkEndpoint so we ignore Write's return
+ // value and always return nil from this method.
+ _ = e.q.Write(p)
return nil
}
@@ -290,3 +295,18 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
// AddHeader implements stack.LinkEndpoint.AddHeader.
func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ p := PacketInfo{
+ Pkt: pkt,
+ Proto: pkt.NetworkProtocolNumber,
+ }
+
+ // Write returns false if the queue is full. A full queue is not an error
+ // from the perspective of a LinkEndpoint so we ignore Write's return
+ // value and always return nil from this method.
+ _ = e.q.Write(p)
+
+ return nil
+}
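On the consuming side nothing changes: packets written through either path land in the endpoint's queue, with Proto filled from the PacketBuffer itself for raw writes. A sketch, assuming the channel endpoint's Read accessor:

if pktInfo, ok := ep.Read(); ok {
	// For WriteRawPacket, Proto was copied from pkt.NetworkProtocolNumber.
	fmt.Println(pktInfo.Proto)
}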
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
index b427c6170..8211a2031 100644
--- a/pkg/tcpip/link/ethernet/ethernet.go
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -42,6 +42,14 @@ type Endpoint struct {
nested.Endpoint
}
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ if l := e.Endpoint.LinkAddress(); len(l) != 0 {
+ return l
+ }
+ return header.UnspecifiedEthernetAddress
+}
+
// DeliverNetworkPacket implements stack.NetworkDispatcher.
func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
@@ -57,18 +65,22 @@ func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkP
// Capabilities implements stack.LinkEndpoint.
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities()
+ c := e.Endpoint.Capabilities()
+ if c&stack.CapabilityLoopback == 0 {
+ c |= stack.CapabilityResolutionRequired
+ }
+ return c
}
// WritePacket implements stack.LinkEndpoint.
func (e *Endpoint) WritePacket(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
+ e.AddHeader(e.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
return e.Endpoint.WritePacket(r, proto, pkt)
}
// WritePackets implements stack.LinkEndpoint.
func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- linkAddr := e.Endpoint.LinkAddress()
+ linkAddr := e.LinkAddress()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt)
@@ -83,7 +95,10 @@ func (e *Endpoint) MaxHeaderLength() uint16 {
}
// ARPHardwareType implements stack.LinkEndpoint.
-func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ if a := e.Endpoint.ARPHardwareType(); a != header.ARPHardwareNone {
+ return a
+ }
return header.ARPHardwareEther
}
@@ -97,3 +112,8 @@ func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkP
}
eth.Encode(&fields)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.Endpoint.WriteRawPacket(pkt)
+}
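
A short sketch of what the wrapper now guarantees, assuming the package's existing New constructor that wraps a lower endpoint:

    package main

    import (
        "gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
    )

    func wrap(lower stack.LinkEndpoint) stack.LinkEndpoint {
        eth := ethernet.New(lower)
        // With the changes above: a loopback-capable lower endpoint no longer
        // has CapabilityResolutionRequired forced onto it, and an address-less
        // lower endpoint reports header.UnspecifiedEthernetAddress rather than
        // an empty link address.
        _ = eth.Capabilities()
        _ = eth.LinkAddress()
        return eth
    }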
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 48356c343..058242f96 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -505,6 +505,9 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
}
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
+
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 7012d8829..ca1f9c08d 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -76,19 +76,8 @@ func (*endpoint) Wait() {}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- // Construct data as the unparsed portion for the loopback packet.
- data := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
-
- // Because we're immediately turning around and writing the packet back
- // to the rx path, we intentionally don't preserve the remote and local
- // link addresses from the stack.Route we're passed.
- newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: data,
- })
- e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt)
-
- return nil
+func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ return e.WriteRawPacket(pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -103,3 +92,19 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ // Construct data as the unparsed portion for the loopback packet.
+ data := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+
+ // Because we're immediately turning around and writing the packet back
+ // to the rx path, we intentionally don't preserve the remote and local
+ // link addresses from the stack.Route we're passed.
+ newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: data,
+ })
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, pkt.NetworkProtocolNumber, newPkt)
+
+ return nil
+}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 3e2a1aa94..844f5959b 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -131,6 +131,11 @@ func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*InjectableEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 3e816b0c7..83a6c1cc8 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -60,16 +60,6 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
}
}
-// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
-func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.mu.RLock()
- d := e.dispatcher
- e.mu.RUnlock()
- if d != nil {
- d.DeliverOutboundPacket(remote, local, protocol, pkt)
- }
-}
-
// Attach implements stack.LinkEndpoint.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.mu.Lock()
@@ -152,3 +142,8 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.child.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.child.WriteRawPacket(pkt)
+}
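
The child delegation added here is what keeps thin wrappers working without any knowledge of raw writes. A minimal pass-through endpoint in the style of the (deleted) packetsocket package below, using only the Init pattern shown there:

    package passthrough

    import (
        "gvisor.dev/gvisor/pkg/tcpip/link/nested"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
    )

    type endpoint struct {
        nested.Endpoint
    }

    // New wraps lower; WritePacket, WriteRawPacket, etc. are all inherited
    // from nested.Endpoint's delegation to the child.
    func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
        e := &endpoint{}
        e.Endpoint.Init(lower, e)
        return e
    }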
diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD
deleted file mode 100644
index 6fff160ce..000000000
--- a/pkg/tcpip/link/packetsocket/BUILD
+++ /dev/null
@@ -1,14 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "packetsocket",
- srcs = ["endpoint.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/link/nested",
- "//pkg/tcpip/stack",
- ],
-)
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
deleted file mode 100644
index e01837e2d..000000000
--- a/pkg/tcpip/link/packetsocket/endpoint.go
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package packetsocket provides a link layer endpoint that provides the ability
-// to loop outbound packets to any AF_PACKET sockets that may be interested in
-// the outgoing packet.
-package packetsocket
-
-import (
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/link/nested"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
-)
-
-type endpoint struct {
- nested.Endpoint
-}
-
-// New creates a new packetsocket LinkEndpoint.
-func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
- e := &endpoint{}
- e.Endpoint.Init(lower, e)
- return e
-}
-
-// WritePacket implements stack.LinkEndpoint.WritePacket.
-func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
- return e.Endpoint.WritePacket(r, protocol, pkt)
-}
-
-// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
- }
-
- return e.Endpoint.WritePackets(r, pkts, proto)
-}
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 5030b6ba1..3ed0aa3fe 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -121,3 +121,6 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
// AddHeader implements stack.LinkEndpoint.
func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index 40bd5560b..b41e3e2fa 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -108,11 +108,6 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
}
-// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
-func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
-}
-
// Attach implements stack.LinkEndpoint.Attach.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
// nil means the NIC is being removed.
@@ -228,3 +223,8 @@ func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.lower.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.lower.WriteRawPacket(pkt)
+}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
index 298bad55d..f2c230720 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
+++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
@@ -27,7 +27,7 @@ TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
MOVQ $0x0, R10 // sigmask parameter which isn't used here
MOVQ $0x10f, AX // SYS_PPOLL
SYSCALL
- CMPQ AX, $0xfffffffffffff001
+ CMPQ AX, $0xfffffffffffff002
JLS ok
MOVQ $-1, n+24(FP)
NEGQ AX
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
index b62888b93..8807586c7 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
+++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
@@ -27,7 +27,7 @@ TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
MOVD $0x0, R3 // sigmask parameter which isn't used here
MOVD $0x49, R8 // SYS_PPOLL
SVC
- CMP $0xfffffffffffff001, R0
+ CMP $0xfffffffffffff002, R0
BLS ok
MOVD $-1, R1
MOVD R1, n+24(FP)
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index e76fc55b6..e53789d92 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -152,10 +152,22 @@ type PollEvent struct {
// no data is available, it will block in a poll() syscall until the file
// descriptor becomes readable.
func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
+ n, err := BlockingReadUntranslated(fd, b)
+ if err != 0 {
+ return n, TranslateErrno(err)
+ }
+ return n, nil
+}
+
+// BlockingReadUntranslated reads from a file descriptor that is set up as
+// non-blocking. If no data is available, it will block in a poll() syscall
+// until the file descriptor becomes readable. It returns the raw unix.Errno
+// value returned by the underlying syscalls.
+func BlockingReadUntranslated(fd int, b []byte) (int, unix.Errno) {
for {
n, _, e := unix.RawSyscall(unix.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
if e == 0 {
- return int(n), nil
+ return int(n), 0
}
event := PollEvent{
@@ -165,7 +177,7 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
_, e = BlockingPoll(&event, 1, nil)
if e != 0 && e != unix.EINTR {
- return 0, TranslateErrno(e)
+ return 0, e
}
}
}
@@ -181,7 +193,9 @@ func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, tcpip
if e == 0 {
return int(n), nil
}
-
+ if e != 0 && e != unix.EWOULDBLOCK {
+ return 0, TranslateErrno(e)
+ }
stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
if stopped {
return -1, nil
@@ -204,6 +218,10 @@ func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MMsgHdr) (int, tcpi
return int(n), nil
}
+ if e != 0 && e != unix.EWOULDBLOCK {
+ return 0, TranslateErrno(e)
+ }
+
stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
if stopped {
return -1, nil
@@ -228,5 +246,13 @@ func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno)
},
}
_, errno := BlockingPoll(&pevents[0], len(pevents), nil)
+ if errno != 0 {
+ return pevents[0].Revents&unix.POLLIN != 0, errno
+ }
+
+ if pevents[1].Revents&unix.POLLHUP != 0 || pevents[1].Revents&unix.POLLERR != 0 {
+ errno = unix.ECONNRESET
+ }
+
return pevents[0].Revents&unix.POLLIN != 0, errno
}
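
The untranslated variant lets callers branch on the host errno before converting it. A hedged sketch of one such caller; the retry-on-EINTR policy is an illustration, not something this package prescribes:

    package main

    import (
        "golang.org/x/sys/unix"

        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
    )

    func readRetrying(fd int, buf []byte) (int, tcpip.Error) {
        for {
            n, errno := rawfile.BlockingReadUntranslated(fd, buf)
            if errno == 0 {
                return n, nil
            }
            if errno == unix.EINTR {
                continue // illustrative policy: retry interrupted reads
            }
            return 0, rawfile.TranslateErrno(errno)
        }
    }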
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index 4215ee852..f8076d83c 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -5,19 +5,26 @@ package(licenses = ["notice"])
go_library(
name = "sharedmem",
srcs = [
+ "queuepair.go",
"rx.go",
+ "server_rx.go",
+ "server_tx.go",
"sharedmem.go",
+ "sharedmem_server.go",
"sharedmem_unsafe.go",
"tx.go",
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/cleanup",
+ "//pkg/eventfd",
"//pkg/log",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/link/sharedmem/pipe",
"//pkg/tcpip/link/sharedmem/queue",
"//pkg/tcpip/stack",
"@org_golang_x_sys//unix:go_default_library",
@@ -26,9 +33,7 @@ go_library(
go_test(
name = "sharedmem_test",
- srcs = [
- "sharedmem_test.go",
- ],
+ srcs = ["sharedmem_test.go"],
library = ":sharedmem",
deps = [
"//pkg/sync",
@@ -41,3 +46,22 @@ go_test(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "sharedmem_server_test",
+ size = "small",
+ srcs = ["sharedmem_server_test.go"],
+ deps = [
+ ":sharedmem",
+ "//pkg/tcpip",
+ "//pkg/tcpip/adapters/gonet",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go
index 696e6c9e5..a78826ebc 100644
--- a/pkg/tcpip/link/sharedmem/queue/rx.go
+++ b/pkg/tcpip/link/sharedmem/queue/rx.go
@@ -119,7 +119,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
}
r.tx.Flush()
-
return true
}
@@ -131,7 +130,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) {
for {
outBufs := bufs
-
// Pull the next descriptor from the rx pipe.
b := r.rx.Pull()
if b == nil {
diff --git a/pkg/tcpip/link/sharedmem/queuepair.go b/pkg/tcpip/link/sharedmem/queuepair.go
new file mode 100644
index 000000000..b12647fdd
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queuepair.go
@@ -0,0 +1,199 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "fmt"
+ "io/ioutil"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
+)
+
+const (
+ // defaultQueueDataSize is the size of the shared memory data region that
+ // holds the scatter/gather buffers.
+ defaultQueueDataSize = 1 << 20 // 1MiB
+
+ // defaultQueuePipeSize is the size of the pipe that holds the packet descriptors.
+ //
+ // Assuming each packet is approximately 1280 bytes (the IPv6 minimum
+ // MTU), we can hold approximately 1024*1024/1280 ~ 819 packets in the
+ // data area, which means the pipe needs to be big enough to hold 819
+ // descriptors.
+ //
+ // Each descriptor is approximately 8 (slot descriptor in pipe) +
+ // 16 (packet descriptor) + 12 (buffer descriptor) = 36 bytes, assuming
+ // each packet is stored in exactly 1 buffer descriptor (see queue/tx.go
+ // and pipe/tx.go).
+ //
+ // That means we need approximately 36*819 ~ 29 KiB to store all packet
+ // descriptors. A 32 KiB pipe would do, but we double it to give the
+ // upper layer some slack in how it makes use of the scatter/gather
+ // buffers.
+ defaultQueuePipeSize = 64 << 10 // 64KiB
+
+ // defaultSharedDataSize is the size of the sharedData region used to
+ // enable/disable notifications.
+ defaultSharedDataSize = 4 << 10 // 4KiB
+)
+
+// A QueuePair represents a pair of TX/RX queues.
+type QueuePair struct {
+ // txCfg is the QueueConfig to be used for the transmit queue.
+ txCfg QueueConfig
+
+ // rxCfg is the QueueConfig to be used for the receive queue.
+ rxCfg QueueConfig
+}
+
+// NewQueuePair creates a shared memory QueuePair.
+func NewQueuePair() (*QueuePair, error) {
+ txCfg, err := createQueueFDs(queueSizes{
+ dataSize: defaultQueueDataSize,
+ txPipeSize: defaultQueuePipeSize,
+ rxPipeSize: defaultQueuePipeSize,
+ sharedDataSize: defaultSharedDataSize,
+ })
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to create tx queue: %s", err)
+ }
+
+ rxCfg, err := createQueueFDs(queueSizes{
+ dataSize: defaultQueueDataSize,
+ txPipeSize: defaultQueuePipeSize,
+ rxPipeSize: defaultQueuePipeSize,
+ sharedDataSize: defaultSharedDataSize,
+ })
+
+ if err != nil {
+ closeFDs(txCfg)
+ return nil, fmt.Errorf("failed to create rx queue: %s", err)
+ }
+
+ return &QueuePair{
+ txCfg: txCfg,
+ rxCfg: rxCfg,
+ }, nil
+}
+
+// Close closes underlying tx/rx queue fds.
+func (q *QueuePair) Close() {
+ closeFDs(q.txCfg)
+ closeFDs(q.rxCfg)
+}
+
+// TXQueueConfig returns the QueueConfig for the transmit queue.
+func (q *QueuePair) TXQueueConfig() QueueConfig {
+ return q.txCfg
+}
+
+// RXQueueConfig returns the QueueConfig for the receive queue.
+func (q *QueuePair) RXQueueConfig() QueueConfig {
+ return q.rxCfg
+}
+
+type queueSizes struct {
+ dataSize int64
+ txPipeSize int64
+ rxPipeSize int64
+ sharedDataSize int64
+}
+
+func createQueueFDs(s queueSizes) (QueueConfig, error) {
+ success := false
+ var eventFD eventfd.Eventfd
+ var dataFD, txPipeFD, rxPipeFD, sharedDataFD int
+ defer func() {
+ if success {
+ return
+ }
+ closeFDs(QueueConfig{
+ EventFD: eventFD,
+ DataFD: dataFD,
+ TxPipeFD: txPipeFD,
+ RxPipeFD: rxPipeFD,
+ SharedDataFD: sharedDataFD,
+ })
+ }()
+ eventFD, err := eventfd.Create()
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("eventfd failed: %v", err)
+ }
+ dataFD, err = createFile(s.dataSize, false)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create dataFD: %s", err)
+ }
+ txPipeFD, err = createFile(s.txPipeSize, true)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create txPipeFD: %s", err)
+ }
+ rxPipeFD, err = createFile(s.rxPipeSize, true)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create rxPipeFD: %s", err)
+ }
+ sharedDataFD, err = createFile(s.sharedDataSize, false)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create sharedDataFD: %s", err)
+ }
+ success = true
+ return QueueConfig{
+ EventFD: eventFD,
+ DataFD: dataFD,
+ TxPipeFD: txPipeFD,
+ RxPipeFD: rxPipeFD,
+ SharedDataFD: sharedDataFD,
+ }, nil
+}
+
+func createFile(size int64, initQueue bool) (fd int, err error) {
+ const tmpDir = "/dev/shm/"
+ f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
+ if err != nil {
+ return -1, fmt.Errorf("TempFile failed: %v", err)
+ }
+ defer f.Close()
+ unix.Unlink(f.Name())
+
+ if initQueue {
+ // Write the "slot-free" flag in the initial queue.
+ if _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0); err != nil {
+ return -1, fmt.Errorf("WriteAt failed: %v", err)
+ }
+ }
+
+ fd, err = unix.Dup(int(f.Fd()))
+ if err != nil {
+ return -1, fmt.Errorf("unix.Dup(%d) failed: %v", f.Fd(), err)
+ }
+
+ if err := unix.Ftruncate(fd, size); err != nil {
+ unix.Close(fd)
+ return -1, fmt.Errorf("ftruncate(%d, %d) failed: %v", fd, size, err)
+ }
+
+ return fd, nil
+}
+
+func closeFDs(c QueueConfig) {
+ unix.Close(c.DataFD)
+ c.EventFD.Close()
+ unix.Close(c.TxPipeFD)
+ unix.Close(c.RxPipeFD)
+ unix.Close(c.SharedDataFD)
+}
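
Intended usage, mirroring the server test added later in this change: one QueuePair backs both ends of the link. A sketch:

    package main

    import "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem"

    func makeQueues() (*sharedmem.QueuePair, error) {
        q, err := sharedmem.NewQueuePair()
        if err != nil {
            return nil, err
        }
        _ = q.TXQueueConfig() // transmit-side config
        _ = q.RXQueueConfig() // receive-side config
        return q, nil // the caller must q.Close() when done
    }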
diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go
index e882a128c..87747dcc7 100644
--- a/pkg/tcpip/link/sharedmem/rx.go
+++ b/pkg/tcpip/link/sharedmem/rx.go
@@ -21,7 +21,7 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
)
@@ -30,7 +30,7 @@ type rx struct {
data []byte
sharedData []byte
q queue.Rx
- eventFD int
+ eventFD eventfd.Eventfd
}
// init initializes all state needed by the rx queue based on the information
@@ -68,7 +68,7 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error {
// Duplicate the eventFD so that caller can close it but we can still
// use it.
- efd, err := unix.Dup(c.EventFD)
+ efd, err := c.EventFD.Dup()
if err != nil {
unix.Munmap(txPipe)
unix.Munmap(rxPipe)
@@ -77,16 +77,6 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error {
return err
}
- // Set the eventfd as non-blocking.
- if err := unix.SetNonblock(efd, true); err != nil {
- unix.Munmap(txPipe)
- unix.Munmap(rxPipe)
- unix.Munmap(data)
- unix.Munmap(sharedData)
- unix.Close(efd)
- return err
- }
-
// Initialize state based on buffers.
r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData))
r.data = data
@@ -105,7 +95,13 @@ func (r *rx) cleanup() {
unix.Munmap(r.data)
unix.Munmap(r.sharedData)
- unix.Close(r.eventFD)
+ r.eventFD.Close()
+}
+
+// notify writes to the rx eventFD to indicate to the peer that there is data
+// to be read.
+func (r *rx) notify() {
+ r.eventFD.Notify()
}
// postAndReceive posts the provided buffers (if any), and then tries to read
@@ -122,8 +118,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.
if len(b) != 0 && !r.q.PostBuffers(b) {
r.q.EnableNotification()
for !r.q.PostBuffers(b) {
- var tmp [8]byte
- rawfile.BlockingRead(r.eventFD, tmp[:])
+ r.eventFD.Wait()
if atomic.LoadUint32(stopRequested) != 0 {
r.q.DisableNotification()
return nil, 0
@@ -147,8 +142,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.
}
// Wait for notification.
- var tmp [8]byte
- rawfile.BlockingRead(r.eventFD, tmp[:])
+ r.eventFD.Wait()
if atomic.LoadUint32(stopRequested) != 0 {
r.q.DisableNotification()
return nil, 0
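
The eventfd wrapper this change switches to hides the raw 8-byte reads and writes that were previously done through rawfile. The notify/wait pairing used throughout, as a sketch:

    package main

    import "gvisor.dev/gvisor/pkg/eventfd"

    func demo() error {
        efd, err := eventfd.Create()
        if err != nil {
            return err
        }
        defer efd.Close()
        go efd.Notify() // producer side: wakes any waiter
        efd.Wait()      // consumer side: blocks until notified
        return nil
    }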
diff --git a/pkg/tcpip/link/sharedmem/server_rx.go b/pkg/tcpip/link/sharedmem/server_rx.go
new file mode 100644
index 000000000..6ea21ffd1
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/server_rx.go
@@ -0,0 +1,142 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/eventfd"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+type serverRx struct {
+ // packetPipe represents the receive end of the pipe that carries the packet
+ // descriptors sent by the client.
+ packetPipe pipe.Rx
+
+ // completionPipe represents the transmit end of the pipe that will carry
+ // completion notifications from the server to the client.
+ completionPipe pipe.Tx
+
+ // data represents the buffer area where the packet payload is held.
+ data []byte
+
+ // eventFD is used to wait for notifications from the peer that packets
+ // are available to be received.
+ eventFD eventfd.Eventfd
+
+ // sharedData is the memory region used to enable/disable notifications.
+ sharedData []byte
+}
+
+// init initializes all state needed by the serverRx queue based on the
+// information provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+func (s *serverRx) init(c *QueueConfig) error {
+ // Map in all buffers.
+ packetPipeMem, err := getBuffer(c.TxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu := cleanup.Make(func() { unix.Munmap(packetPipeMem) })
+ defer cu.Clean()
+
+ completionPipeMem, err := getBuffer(c.RxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(completionPipeMem) })
+
+ data, err := getBuffer(c.DataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(data) })
+
+ sharedData, err := getBuffer(c.SharedDataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(sharedData) })
+
+ // Duplicate the eventFD so that caller can close it but we can still
+ // use it.
+ efd, err := c.EventFD.Dup()
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { efd.Close() })
+
+ s.packetPipe.Init(packetPipeMem)
+ s.completionPipe.Init(completionPipeMem)
+ s.data = data
+ s.eventFD = efd
+ s.sharedData = sharedData
+
+ cu.Release()
+ return nil
+}
+
+func (s *serverRx) cleanup() {
+ unix.Munmap(s.packetPipe.Bytes())
+ unix.Munmap(s.completionPipe.Bytes())
+ unix.Munmap(s.data)
+ unix.Munmap(s.sharedData)
+ s.eventFD.Close()
+}
+
+// completionNotificationSize is the size in bytes of a completion
+// notification sent on the completion queue after a transmitted packet has
+// been handled.
+const completionNotificationSize = 8
+
+// receive receives a single packet from the packetPipe.
+func (s *serverRx) receive() []byte {
+ desc := s.packetPipe.Pull()
+ if desc == nil {
+ return nil
+ }
+
+ pktInfo := queue.DecodeTxPacketHeader(desc)
+ contents := make([]byte, 0, pktInfo.Size)
+ toCopy := pktInfo.Size
+ for i := 0; i < pktInfo.BufferCount; i++ {
+ txBuf := queue.DecodeTxBufferHeader(desc, i)
+ if txBuf.Size <= toCopy {
+ contents = append(contents, s.data[txBuf.Offset:][:txBuf.Size]...)
+ toCopy -= txBuf.Size
+ continue
+ }
+ contents = append(contents, s.data[txBuf.Offset:][:toCopy]...)
+ break
+ }
+
+ // Flush to let the peer know that the slots queued for transmission have
+ // been handled, and it's free to reuse them.
+ s.packetPipe.Flush()
+ // Encode packet completion.
+ b := s.completionPipe.Push(completionNotificationSize)
+ queue.EncodeTxCompletion(b, pktInfo.ID)
+ s.completionPipe.Flush()
+ return contents
+}
+
+func (s *serverRx) waitForPackets() {
+ s.eventFD.Wait()
+}
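
The intended consumption pattern, mirrored by dispatchLoop in sharedmem_server.go further down: drain the pipe, then block on the eventfd once it runs dry. A package-internal sketch (serverRx is unexported), with handle standing in for real packet processing:

    func serve(s *serverRx, handle func([]byte)) {
        for {
            b := s.receive()
            if b == nil {
                s.waitForPackets() // block until the peer notifies
                continue
            }
            handle(b)
        }
    }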
diff --git a/pkg/tcpip/link/sharedmem/server_tx.go b/pkg/tcpip/link/sharedmem/server_tx.go
new file mode 100644
index 000000000..13a82903f
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/server_tx.go
@@ -0,0 +1,175 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/eventfd"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+// serverTx represents the server end of the sharedmem queue and is used to
+// send packets to the peer using the buffers the peer posted via the fillPipe.
+type serverTx struct {
+ // fillPipe represents the receive end of the pipe that carries the RxBuffers
+ // posted by the peer.
+ fillPipe pipe.Rx
+
+ // completionPipe represents the transmit end of the pipe that carries the
+ // descriptors for filled RxBuffers.
+ completionPipe pipe.Tx
+
+ // data represents the buffer area where the packet payload is held.
+ data []byte
+
+ // eventFD is used to notify the peer when fill requests are fulfilled.
+ eventFD eventfd.Eventfd
+
+ // sharedData is the memory region used to enable/disable notifications.
+ sharedData []byte
+}
+
+// init initializes all state needed by the serverTx queue based on the
+// information provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+func (s *serverTx) init(c *QueueConfig) error {
+ // Map in all buffers.
+ fillPipeMem, err := getBuffer(c.TxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu := cleanup.Make(func() { unix.Munmap(fillPipeMem) })
+ defer cu.Clean()
+
+ completionPipeMem, err := getBuffer(c.RxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(completionPipeMem) })
+
+ data, err := getBuffer(c.DataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(data) })
+
+ sharedData, err := getBuffer(c.SharedDataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(sharedData) })
+
+ // Duplicate the eventFD so that caller can close it but we can still
+ // use it.
+ efd, err := c.EventFD.Dup()
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { efd.Close() })
+
+ cu.Release()
+
+ s.fillPipe.Init(fillPipeMem)
+ s.completionPipe.Init(completionPipeMem)
+ s.data = data
+ s.eventFD = efd
+ s.sharedData = sharedData
+
+ return nil
+}
+
+func (s *serverTx) cleanup() {
+ unix.Munmap(s.fillPipe.Bytes())
+ unix.Munmap(s.completionPipe.Bytes())
+ unix.Munmap(s.data)
+ unix.Munmap(s.sharedData)
+ s.eventFD.Close()
+}
+
+// fillPacket copies the data in the provided views into buffers pulled from the
+// fillPipe and returns a slice of RxBuffers that contain the copied data as
+// well as the total number of bytes copied.
+//
+// To avoid allocations, filledBuffers is appended to the provided buffers
+// slice, which is grown as required.
+func (s *serverTx) fillPacket(views []buffer.View, buffers []queue.RxBuffer) (filledBuffers []queue.RxBuffer, totalCopied uint32) {
+ filledBuffers = buffers[:0]
+ // fillBuffer copies as much of the views as possible into the provided buffer
+ // and returns any leftover views.
+ fillBuffer := func(buffer *queue.RxBuffer, views []buffer.View) (left []buffer.View) {
+ if len(views) == 0 {
+ return nil
+ }
+ availBytes := buffer.Size
+ copied := uint64(0)
+ for availBytes > 0 && len(views) > 0 {
+ n := copy(s.data[buffer.Offset+copied:][:uint64(buffer.Size)-copied], views[0])
+ views[0].TrimFront(n)
+ if !views[0].IsEmpty() {
+ break
+ }
+ views = views[1:]
+ copied += uint64(n)
+ availBytes -= uint32(n)
+ }
+ buffer.Size = uint32(copied)
+ return views
+ }
+
+ for len(views) > 0 {
+ var b []byte
+ // Spin till we get a free buffer reposted by the peer.
+ for {
+ if b = s.fillPipe.Pull(); b != nil {
+ break
+ }
+ }
+ rxBuffer := queue.DecodeRxBufferHeader(b)
+ // Copy the packet into the posted buffer.
+ views = fillBuffer(&rxBuffer, views)
+ totalCopied += rxBuffer.Size
+ filledBuffers = append(filledBuffers, rxBuffer)
+ }
+
+ return filledBuffers, totalCopied
+}
+
+func (s *serverTx) transmit(views []buffer.View) bool {
+ buffers := make([]queue.RxBuffer, 8)
+ buffers, totalCopied := s.fillPacket(views, buffers)
+ b := s.completionPipe.Push(queue.RxCompletionSize(len(buffers)))
+ if b == nil {
+ return false
+ }
+ queue.EncodeRxCompletion(b, totalCopied, 0 /* reserved */)
+ for i := 0; i < len(buffers); i++ {
+ queue.EncodeRxCompletionBuffer(b, i, buffers[i])
+ }
+ s.completionPipe.Flush()
+ s.fillPipe.Flush()
+ return true
+}
+
+func (s *serverTx) notify() {
+ s.eventFD.Notify()
+}
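
And the matching transmit path, again package-internal; this is the shape writePacketLocked in sharedmem_server.go gives it (buffer and tcpip imports assumed):

    // send copies one packet's views into peer-posted buffers, then wakes
    // the peer. A failed transmit means the completion pipe had no space.
    func send(s *serverTx, views []buffer.View) tcpip.Error {
        if !s.transmit(views) {
            return &tcpip.ErrWouldBlock{}
        }
        s.notify()
        return nil
    }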
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 30cf659b8..bcb37a465 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -24,14 +24,16 @@
package sharedmem
import (
+ "fmt"
"sync/atomic"
- "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -47,7 +49,7 @@ type QueueConfig struct {
// EventFD is a file descriptor for the event that is signaled when
// data becomes available in this queue.
- EventFD int
+ EventFD eventfd.Eventfd
// TxPipeFD is a file descriptor for the tx pipe associated with the
// queue.
@@ -63,16 +65,97 @@ type QueueConfig struct {
SharedDataFD int
}
+// FDs returns the FDs in the QueueConfig as a slice of ints. It must be
+// used in conjunction with QueueConfigFromFDs to ensure the order of FDs
+// matches when the config is reconstructed after being serialized or sent
+// as part of a control message.
+func (q *QueueConfig) FDs() []int {
+ return []int{q.DataFD, q.EventFD.FD(), q.TxPipeFD, q.RxPipeFD, q.SharedDataFD}
+}
+
+// QueueConfigFromFDs constructs a QueueConfig out of a slice of ints where
+// each entry represents a file descriptor. The FDs in the slice must appear
+// in the order specified below for the config to be valid. QueueConfig.FDs()
+// should be used when the config needs to be serialized or sent as part of a
+// control message to ensure the correct order.
+func QueueConfigFromFDs(fds []int) (QueueConfig, error) {
+ if len(fds) != 5 {
+ return QueueConfig{}, fmt.Errorf("insufficient number of fds: len(fds): %d, want: 5", len(fds))
+ }
+ return QueueConfig{
+ DataFD: fds[0],
+ EventFD: eventfd.Wrap(fds[1]),
+ TxPipeFD: fds[2],
+ RxPipeFD: fds[3],
+ SharedDataFD: fds[4],
+ }, nil
+}
+
+// Options specify the details about the sharedmem endpoint to be created.
+type Options struct {
+ // MTU is the mtu to use for this endpoint.
+ MTU uint32
+
+ // BufferSize is the size of each scatter/gather buffer that will hold packet
+ // data.
+ //
+ // NOTE: This directly determines the number of packets that can be held
+ // in the ring buffer at any time. It does not have to be sized to the MTU,
+ // as the shared memory queue design allows more than one buffer to be used
+ // to make up a given packet.
+ BufferSize uint32
+
+ // LinkAddress is the link address for this endpoint (required).
+ LinkAddress tcpip.LinkAddress
+
+ // TX is the transmit queue configuration for this shared memory endpoint.
+ TX QueueConfig
+
+ // RX is the receive queue configuration for this shared memory endpoint.
+ RX QueueConfig
+
+ // PeerFD is the fd for the connected peer which can be used to detect
+ // peer disconnects.
+ PeerFD int
+
+ // OnClosed is a function that is called when the endpoint is being closed
+ // (probably due to the peer going away).
+ OnClosed func(err tcpip.Error)
+
+ // TXChecksumOffload, if true, indicates that this endpoint's capability
+ // set should include CapabilityTXChecksumOffload.
+ TXChecksumOffload bool
+
+ // RXChecksumOffload, if true, indicates that this endpoint's capability
+ // set should include CapabilityRXChecksumOffload.
+ RXChecksumOffload bool
+}
+
type endpoint struct {
// mtu (maximum transmission unit) is the maximum size of a packet.
+ // mtu is immutable.
mtu uint32
// bufferSize is the size of each individual buffer.
+ // bufferSize is immutable.
bufferSize uint32
// addr is the local address of this endpoint.
+ // addr is immutable.
addr tcpip.LinkAddress
+ // peerFD is an fd to the peer that can be used to detect when the
+ // peer is gone.
+ // peerFD is immutable.
+ peerFD int
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // hdrSize is the size of the link layer header if any.
+ // hdrSize is immutable.
+ hdrSize uint32
+
// rx is the receive queue.
rx rx
@@ -83,34 +166,55 @@ type endpoint struct {
// Wait group used to indicate that all workers have stopped.
completed sync.WaitGroup
+ // onClosed is a function to be called when the FD's peer (if any) closes
+ // its end of the communication pipe.
+ onClosed func(tcpip.Error)
+
// mu protects the following fields.
mu sync.Mutex
// tx is the transmit queue.
+ // +checklocks:mu
tx tx
// workerStarted specifies whether the worker goroutine was started.
+ // +checklocks:mu
workerStarted bool
}
// New creates a new shared-memory-based endpoint. Buffers will be broken up
// into buffers of "bufferSize" bytes.
-func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) {
+func New(opts Options) (stack.LinkEndpoint, error) {
e := &endpoint{
- mtu: mtu,
- bufferSize: bufferSize,
- addr: addr,
+ mtu: opts.MTU,
+ bufferSize: opts.BufferSize,
+ addr: opts.LinkAddress,
+ peerFD: opts.PeerFD,
+ onClosed: opts.OnClosed,
}
- if err := e.tx.init(bufferSize, &tx); err != nil {
+ if err := e.tx.init(opts.BufferSize, &opts.TX); err != nil {
return nil, err
}
- if err := e.rx.init(bufferSize, &rx); err != nil {
+ if err := e.rx.init(opts.BufferSize, &opts.RX); err != nil {
e.tx.cleanup()
return nil, err
}
+ e.caps = stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ e.caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ e.caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ if opts.LinkAddress != "" {
+ e.hdrSize = header.EthernetMinimumSize
+ e.caps |= stack.CapabilityResolutionRequired
+ }
return e, nil
}
@@ -119,13 +223,13 @@ func (e *endpoint) Close() {
// Tell dispatch goroutine to stop, then write to the eventfd so that
// it wakes up in case it's sleeping.
atomic.StoreUint32(&e.stopRequested, 1)
- unix.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ e.rx.eventFD.Notify()
// Cleanup the queues inline if the worker hasn't started yet; we also
// know it won't start from now on because stopRequested is set to 1.
e.mu.Lock()
+ defer e.mu.Unlock()
workerPresent := e.workerStarted
- e.mu.Unlock()
if !workerPresent {
e.tx.cleanup()
@@ -146,6 +250,22 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 {
e.workerStarted = true
e.completed.Add(1)
+
+ // Spin up a goroutine to monitor for peer shutdown.
+ if e.peerFD >= 0 {
+ e.completed.Add(1)
+ go func() {
+ defer e.completed.Done()
+ b := make([]byte, 1)
+ // When the sharedmem endpoint is in use, the peerFD is never used for any
+ // data transfer, so this Read should only return if the peer is shutting
+ // down.
+ _, err := rawfile.BlockingRead(e.peerFD, b)
+ if e.onClosed != nil {
+ e.onClosed(err)
+ }
+ }()
+ }
+
// Link endpoints are not savable. When transportation endpoints
// are saved, they stop sending outgoing packets and all
// incoming packets are rejected.
@@ -164,18 +284,18 @@ func (e *endpoint) IsAttached() bool {
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
// during construction.
func (e *endpoint) MTU() uint32 {
- return e.mtu - header.EthernetMinimumSize
+ return e.mtu - e.hdrSize
}
// Capabilities implements stack.LinkEndpoint.Capabilities.
-func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return 0
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
}
// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
// ethernet frame header size.
-func (*endpoint) MaxHeaderLength() uint16 {
- return header.EthernetMinimumSize
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
}
// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
@@ -202,17 +322,18 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
eth.Encode(ethHdr)
}
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
+
+// +checklocks:e.mu
+func (e *endpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ if e.addr != "" {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
views := pkt.Views()
// Transmit the packet.
- e.mu.Lock()
ok := e.tx.transmit(views...)
- e.mu.Unlock()
-
if !ok {
return &tcpip.ErrWouldBlock{}
}
@@ -220,9 +341,37 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
return nil
}
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if err := e.writePacketLocked(r, protocol, pkt); err != nil {
+ return err
+ }
+ e.tx.notify()
+ return nil
+}
+
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- panic("not implemented")
+func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ n := 0
+ var err tcpip.Error
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil {
+ break
+ }
+ n++
+ }
+ // WritePackets never returns an error if it successfully transmitted at least
+ // one packet.
+ if err != nil && n == 0 {
+ return 0, err
+ }
+ e.tx.notify()
+ return n, nil
}
// dispatchLoop reads packets from the rx queue in a loop and dispatches them
@@ -265,16 +414,42 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
Data: buffer.View(b).ToVectorisedView(),
})
- hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
- if !ok {
- continue
+ var src, dst tcpip.LinkAddress
+ var proto tcpip.NetworkProtocolNumber
+ if e.addr != "" {
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ continue
+ }
+ eth := header.Ethernet(hdr)
+ src = eth.SourceAddress()
+ dst = eth.DestinationAddress()
+ proto = eth.Type()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ // The IP version is in the first octet, so pull up 1 byte.
+ h, ok := pkt.Data().PullUp(1)
+ if !ok {
+ continue
+ }
+ switch header.IPVersion(h) {
+ case header.IPv4Version:
+ proto = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ proto = header.IPv6ProtocolNumber
+ default:
+ continue
+ }
}
- eth := header.Ethernet(hdr)
// Send packet up the stack.
- d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt)
+ d.DeliverNetworkPacket(src, dst, proto, pkt)
}
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
// Clean state.
e.tx.cleanup()
e.rx.cleanup()
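
The FDs/QueueConfigFromFDs pair exists so a config can cross a process boundary with its fd order intact. A hedged sketch; the SCM_RIGHTS transport is an assumption for illustration, not part of this package:

    package main

    import (
        "golang.org/x/sys/unix"

        "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem"
    )

    func sendConfig(sock int, cfg sharedmem.QueueConfig) error {
        rights := unix.UnixRights(cfg.FDs()...)
        // One dummy data byte; the payload is the control message.
        return unix.Sendmsg(sock, []byte{0}, rights, nil, 0)
    }

    func recvConfig(fds []int) (sharedmem.QueueConfig, error) {
        // fds must arrive in the exact order FDs() produced them.
        return sharedmem.QueueConfigFromFDs(fds)
    }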
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server.go b/pkg/tcpip/link/sharedmem/sharedmem_server.go
new file mode 100644
index 000000000..ccc84989d
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_server.go
@@ -0,0 +1,333 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type serverEndpoint struct {
+ // mtu (maximum transmission unit) is the maximum size of a packet.
+ // mtu is immutable.
+ mtu uint32
+
+ // bufferSize is the size of each individual buffer.
+ // bufferSize is immutable.
+ bufferSize uint32
+
+ // addr is the local address of this endpoint.
+ // addr is immutable.
+ addr tcpip.LinkAddress
+
+ // rx is the receive queue.
+ rx serverRx
+
+ // stopRequested is to be accessed atomically only, and determines if the
+ // worker goroutines should stop.
+ stopRequested uint32
+
+ // Wait group used to indicate that all workers have stopped.
+ completed sync.WaitGroup
+
+ // peerFD is an fd to the peer that can be used to detect when the peer is
+ // gone.
+ // peerFD is immutable.
+ peerFD int
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // hdrSize is the size of the link layer header if any.
+ // hdrSize is immutable.
+ hdrSize uint32
+
+ // onClosed is a function to be called when the FD's peer (if any) closes its
+ // end of the communication pipe.
+ onClosed func(tcpip.Error)
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // tx is the transmit queue.
+ // +checklocks:mu
+ tx serverTx
+
+ // workerStarted specifies whether the worker goroutine was started.
+ // +checklocks:mu
+ workerStarted bool
+}
+
+// NewServerEndpoint creates a new shared-memory-based endpoint. Buffers will be
+// broken up into buffers of "bufferSize" bytes.
+func NewServerEndpoint(opts Options) (stack.LinkEndpoint, error) {
+ e := &serverEndpoint{
+ mtu: opts.MTU,
+ bufferSize: opts.BufferSize,
+ addr: opts.LinkAddress,
+ peerFD: opts.PeerFD,
+ onClosed: opts.OnClosed,
+ }
+
+ if err := e.tx.init(&opts.RX); err != nil {
+ return nil, err
+ }
+
+ if err := e.rx.init(&opts.TX); err != nil {
+ e.tx.cleanup()
+ return nil, err
+ }
+
+ e.caps = stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ e.caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ e.caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ if opts.LinkAddress != "" {
+ e.hdrSize = header.EthernetMinimumSize
+ e.caps |= stack.CapabilityResolutionRequired
+ }
+
+ return e, nil
+}
+
+// Close frees all resources associated with the endpoint.
+func (e *serverEndpoint) Close() {
+ // Tell dispatch goroutine to stop, then write to the eventfd so that it wakes
+ // up in case it's sleeping.
+ atomic.StoreUint32(&e.stopRequested, 1)
+ e.rx.eventFD.Notify()
+
+ // Cleanup the queues inline if the worker hasn't started yet; we also know it
+ // won't start from now on because stopRequested is set to 1.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ workerPresent := e.workerStarted
+
+ if !workerPresent {
+ e.tx.cleanup()
+ e.rx.cleanup()
+ }
+}
+
+// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
+// stopped after a Close() call.
+func (e *serverEndpoint) Wait() {
+ e.completed.Wait()
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that
+// reads packets from the rx queue.
+func (e *serverEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.mu.Lock()
+ if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 {
+ e.workerStarted = true
+ e.completed.Add(1)
+ if e.peerFD >= 0 {
+ e.completed.Add(1)
+ // Spin up a goroutine to monitor for peer shutdown.
+ go func() {
+ b := make([]byte, 1)
+ // When the sharedmem endpoint is in use, the peerFD is never used
+ // for any data transfer, so this Read should only return if the peer
+ // is shutting down.
+ _, err := rawfile.BlockingRead(e.peerFD, b)
+ if e.onClosed != nil {
+ e.onClosed(err)
+ }
+ e.completed.Done()
+ }()
+ }
+ // Link endpoints are not savable. When transportation endpoints are saved,
+ // they stop sending outgoing packets and all incoming packets are rejected.
+ go e.dispatchLoop(dispatcher) // S/R-SAFE: see above.
+ }
+ e.mu.Unlock()
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *serverEndpoint) IsAttached() bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.workerStarted
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *serverEndpoint) MTU() uint32 {
+ return e.mtu - e.hdrSize
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *serverEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
+// ethernet frame header size.
+func (e *serverEndpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
+// link address.
+func (e *serverEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.addr
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *serverEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ // Add ethernet header if needed.
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
+ ethHdr := &header.EthernetFields{
+ DstAddr: remote,
+ Type: protocol,
+ }
+
+ // Preserve the src address if it's set in the route.
+ if local != "" {
+ ethHdr.SrcAddr = local
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*serverEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
+// +checklocks:e.mu
+func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+
+ views := pkt.Views()
+ ok := e.tx.transmit(views)
+ if !ok {
+ return &tcpip.ErrWouldBlock{}
+ }
+
+ return nil
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *serverEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ // Transmit the packet.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if err := e.writePacketLocked(r, protocol, pkt); err != nil {
+ return err
+ }
+ e.tx.notify()
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *serverEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ n := 0
+ var err tcpip.Error
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err = e.writePacketLocked(r, pkt.NetworkProtocolNumber, pkt); err != nil {
+ break
+ }
+ n++
+ }
+ // WritePackets never returns an error if it successfully transmitted at least
+ // one packet.
+ if err != nil && n == 0 {
+ return 0, err
+ }
+ e.tx.notify()
+ return n, nil
+}
+
+// dispatchLoop reads packets from the rx queue in a loop and dispatches them
+// to the network stack.
+func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) {
+ for atomic.LoadUint32(&e.stopRequested) == 0 {
+ b := e.rx.receive()
+ if b == nil {
+ e.rx.waitForPackets()
+ continue
+ }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.View(b).ToVectorisedView(),
+ })
+ var src, dst tcpip.LinkAddress
+ var proto tcpip.NetworkProtocolNumber
+ if e.addr != "" {
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ continue
+ }
+ eth := header.Ethernet(hdr)
+ src = eth.SourceAddress()
+ dst = eth.DestinationAddress()
+ proto = eth.Type()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ // The IP version is in the first octet, so pull up 1 byte.
+ h, ok := pkt.Data().PullUp(1)
+ if !ok {
+ continue
+ }
+ switch header.IPVersion(h) {
+ case header.IPv4Version:
+ proto = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ proto = header.IPv6ProtocolNumber
+ default:
+ continue
+ }
+ }
+ // Send packet up the stack.
+ d.DeliverNetworkPacket(src, dst, proto, pkt)
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Clean state.
+ e.tx.cleanup()
+ e.rx.cleanup()
+
+ e.completed.Done()
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *serverEndpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.hdrSize > 0 {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
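
Note the deliberate crossover in NewServerEndpoint above (tx.init(&opts.RX), rx.init(&opts.TX)): both sides can be handed identical queue configs, as the test below does. A sketch; MTU/BufferSize are placeholder values, and LinkAddress is omitted, which selects the header-less IP mode handled in dispatchLoop:

    func makeEndpoints(q *sharedmem.QueuePair) (client, server stack.LinkEndpoint, err error) {
        opts := sharedmem.Options{
            MTU:        1500,
            BufferSize: 1500,
            TX:         q.TXQueueConfig(),
            RX:         q.RXQueueConfig(),
        }
        if client, err = sharedmem.New(opts); err != nil {
            return nil, nil, err
        }
        // Same opts on purpose: the server swaps TX/RX internally, so the
        // client's transmit pipe is the server's receive pipe.
        if server, err = sharedmem.NewServerEndpoint(opts); err != nil {
            return nil, nil, err // cleanup of client elided
        }
        return client, server, nil
    }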
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server_test.go b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go
new file mode 100644
index 000000000..1bc58614e
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go
@@ -0,0 +1,220 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem_server_test
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "syscall"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+const (
+ localLinkAddr = "\xde\xad\xbe\xef\x56\x78"
+ remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34"
+ localIPv4Address = tcpip.Address("\x0a\x00\x00\x01")
+ remoteIPv4Address = tcpip.Address("\x0a\x00\x00\x02")
+ serverPort = 10001
+
+ defaultMTU = 1500
+ defaultBufferSize = 1500
+)
+
+type stackOptions struct {
+ ep stack.LinkEndpoint
+ addr tcpip.Address
+}
+
+func newStackWithOptions(stackOpts stackOptions) (*stack.Stack, error) {
+ st := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ AllowExternalLoopbackTraffic: true,
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ AllowExternalLoopbackTraffic: true,
+ }),
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
+ })
+ nicID := tcpip.NICID(1)
+ sniffEP := sniffer.New(stackOpts.ep)
+ opts := stack.NICOptions{Name: "eth0"}
+ if err := st.CreateNICWithOptions(nicID, sniffEP, opts); err != nil {
+ return nil, fmt.Errorf("method CreateNICWithOptions(%d, _, %v) failed: %s", nicID, opts, err)
+ }
+
+ // Add Protocol Address.
+ protocolNum := ipv4.ProtocolNumber
+ routeTable := []tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}
+ if len(stackOpts.addr) == 16 {
+ routeTable = []tcpip.Route{{Destination: header.IPv6EmptySubnet, NIC: nicID}}
+ protocolNum = ipv6.ProtocolNumber
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: protocolNum,
+ AddressWithPrefix: stackOpts.addr.WithPrefix(),
+ }
+ if err := st.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, fmt.Errorf("AddProtocolAddress(%d, %v, {}): %s", nicID, protocolAddr, err)
+ }
+
+ // Setup route table.
+ st.SetRouteTable(routeTable)
+
+ return st, nil
+}
+
+func newClientStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) {
+ ep, err := sharedmem.New(sharedmem.Options{
+ MTU: defaultMTU,
+ BufferSize: defaultBufferSize,
+ LinkAddress: localLinkAddr,
+ TX: qPair.TXQueueConfig(),
+ RX: qPair.RXQueueConfig(),
+ PeerFD: peerFD,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
+ }
+ st, err := newStackWithOptions(stackOptions{ep: ep, addr: localIPv4Address})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create client stack: %s", err)
+ }
+ return st, nil
+}
+
+func newServerStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) {
+ ep, err := sharedmem.NewServerEndpoint(sharedmem.Options{
+ MTU: defaultMTU,
+ BufferSize: defaultBufferSize,
+ LinkAddress: remoteLinkAddr,
+ TX: qPair.TXQueueConfig(),
+ RX: qPair.RXQueueConfig(),
+ PeerFD: peerFD,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
+ }
+ st, err := newStackWithOptions(stackOptions{ep: ep, addr: remoteIPv4Address})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create client stack: %s", err)
+ }
+ return st, nil
+}
+
+type testContext struct {
+ clientStk *stack.Stack
+ serverStk *stack.Stack
+ peerFDs [2]int
+}
+
+func newTestContext(t *testing.T) *testContext {
+ peerFDs, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET|syscall.SOCK_NONBLOCK, 0)
+ if err != nil {
+ t.Fatalf("failed to create peerFDs: %s", err)
+ }
+ q, err := sharedmem.NewQueuePair()
+ if err != nil {
+ t.Fatalf("failed to create sharedmem queue: %s", err)
+ }
+ clientStack, err := newClientStack(t, q, peerFDs[0])
+ if err != nil {
+ q.Close()
+ unix.Close(peerFDs[0])
+ unix.Close(peerFDs[1])
+ t.Fatalf("failed to create client stack: %s", err)
+ }
+ serverStack, err := newServerStack(t, q, peerFDs[1])
+ if err != nil {
+ q.Close()
+ unix.Close(peerFDs[0])
+ unix.Close(peerFDs[1])
+ clientStack.Close()
+ t.Fatalf("failed to create server stack: %s", err)
+ }
+ return &testContext{
+ clientStk: clientStack,
+ serverStk: serverStack,
+ peerFDs: peerFDs,
+ }
+}
+
+func (ctx *testContext) cleanup() {
+ unix.Close(ctx.peerFDs[0])
+ unix.Close(ctx.peerFDs[1])
+ ctx.clientStk.Close()
+ ctx.serverStk.Close()
+}
+
+func TestServerRoundTrip(t *testing.T) {
+ ctx := newTestContext(t)
+ defer ctx.cleanup()
+ listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort}
+ l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatalf("failed to start TCP Listener: %s", err)
+ }
+ defer l.Close()
+ var responseString = "response"
+ go func() {
+ http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(responseString))
+ }))
+ }()
+
+ dialFunc := func(network, address string) (net.Conn, error) {
+ return gonet.DialTCP(ctx.clientStk, listenAddr, ipv4.ProtocolNumber)
+ }
+
+ httpClient := &http.Client{
+ Transport: &http.Transport{
+ Dial: dialFunc,
+ },
+ }
+ serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Address), serverPort)
+ response, err := httpClient.Get(serverURL)
+ if err != nil {
+ t.Fatalf("httpClient.Get(\"/\") failed: %s", err)
+ }
+ if got, want := response.StatusCode, http.StatusOK; got != want {
+ t.Fatalf("unexpected status code got: %d, want: %d", got, want)
+ }
+ body, err := io.ReadAll(response.Body)
+ if err != nil {
+ t.Fatalf("io.ReadAll(response.Body) failed: %s", err)
+ }
+ response.Body.Close()
+ if got, want := string(body), responseString; got != want {
+ t.Fatalf("unexpected response got: %s, want: %s", got, want)
+ }
+}
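
Aside: the round-trip test above captures a reusable pattern: pointing a stock net/http client at an in-process netstack by overriding Transport.Dial with a gonet dialer. A minimal sketch under the same imports as the test (the helper name is illustrative, not part of the change):

func newNetstackHTTPClient(st *stack.Stack, serverAddr tcpip.FullAddress) *http.Client {
	return &http.Client{
		Transport: &http.Transport{
			// net/http hands us (network, address); both are ignored and the
			// connection is dialed through the in-process stack instead.
			Dial: func(network, address string) (net.Conn, error) {
				return gonet.DialTCP(st, serverAddr, ipv4.ProtocolNumber)
			},
		},
	}
}
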
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index d6d953085..66ffc33b8 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -19,9 +19,7 @@ package sharedmem
import (
"bytes"
- "io/ioutil"
"math/rand"
- "os"
"strings"
"testing"
"time"
@@ -104,24 +102,36 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
t: t,
packetCh: make(chan struct{}, 1000000),
}
- c.txCfg = createQueueFDs(t, queueSizes{
+ c.txCfg, err = createQueueFDs(queueSizes{
dataSize: queueDataSize,
txPipeSize: queuePipeSize,
rxPipeSize: queuePipeSize,
sharedDataSize: 4096,
})
-
- c.rxCfg = createQueueFDs(t, queueSizes{
+ if err != nil {
+ t.Fatalf("createQueueFDs for tx failed: %s", err)
+ }
+ c.rxCfg, err = createQueueFDs(queueSizes{
dataSize: queueDataSize,
txPipeSize: queuePipeSize,
rxPipeSize: queuePipeSize,
sharedDataSize: 4096,
})
+ if err != nil {
+ t.Fatalf("createQueueFDs for rx failed: %s", err)
+ }
initQueue(t, &c.txq, &c.txCfg)
initQueue(t, &c.rxq, &c.rxCfg)
- ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
+ ep, err := New(Options{
+ MTU: mtu,
+ BufferSize: bufferSize,
+ LinkAddress: addr,
+ TX: c.txCfg,
+ RX: c.rxCfg,
+ PeerFD: -1,
+ })
if err != nil {
t.Fatalf("New failed: %v", err)
}
@@ -150,8 +160,8 @@ func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.
func (c *testContext) cleanup() {
c.ep.Close()
- closeFDs(&c.txCfg)
- closeFDs(&c.rxCfg)
+ closeFDs(c.txCfg)
+ closeFDs(c.rxCfg)
c.txq.cleanup()
c.rxq.cleanup()
}
@@ -191,69 +201,6 @@ func shuffle(b []int) {
}
}
-func createFile(t *testing.T, size int64, initQueue bool) int {
- tmpDir, ok := os.LookupEnv("TEST_TMPDIR")
- if !ok {
- tmpDir = os.Getenv("TMPDIR")
- }
- f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
- if err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
- defer f.Close()
- unix.Unlink(f.Name())
-
- if initQueue {
- // Write the "slot-free" flag in the initial queue.
- _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0)
- if err != nil {
- t.Fatalf("WriteAt failed: %v", err)
- }
- }
-
- fd, err := unix.Dup(int(f.Fd()))
- if err != nil {
- t.Fatalf("Dup failed: %v", err)
- }
-
- if err := unix.Ftruncate(fd, size); err != nil {
- unix.Close(fd)
- t.Fatalf("Ftruncate failed: %v", err)
- }
-
- return fd
-}
-
-func closeFDs(c *QueueConfig) {
- unix.Close(c.DataFD)
- unix.Close(c.EventFD)
- unix.Close(c.TxPipeFD)
- unix.Close(c.RxPipeFD)
- unix.Close(c.SharedDataFD)
-}
-
-type queueSizes struct {
- dataSize int64
- txPipeSize int64
- rxPipeSize int64
- sharedDataSize int64
-}
-
-func createQueueFDs(t *testing.T, s queueSizes) QueueConfig {
- fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0)
- if err != 0 {
- t.Fatalf("eventfd failed: %v", error(err))
- }
-
- return QueueConfig{
- EventFD: int(fd),
- DataFD: createFile(t, s.dataSize, false),
- TxPipeFD: createFile(t, s.txPipeSize, true),
- RxPipeFD: createFile(t, s.rxPipeSize, true),
- SharedDataFD: createFile(t, s.sharedDataSize, false),
- }
-}
-
// TestSimpleSend sends 1000 packets with random header and payload sizes,
// then checks that the right payload is received on the shared memory queues.
func TestSimpleSend(t *testing.T) {
@@ -672,7 +619,7 @@ func TestSimpleReceive(t *testing.T) {
// Push completion.
c.pushRxCompletion(uint32(len(contents)), bufs)
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for packet to be received, then check it.
c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
@@ -718,7 +665,7 @@ func TestRxBuffersReposted(t *testing.T) {
// Complete the buffer.
c.pushRxCompletion(buffers[i].Size, buffers[i:][:1])
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for it to be reposted.
bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
@@ -734,7 +681,7 @@ func TestRxBuffersReposted(t *testing.T) {
// Complete with two buffers.
c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2])
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for them to be reposted.
for j := 0; j < 2; j++ {
@@ -759,7 +706,7 @@ func TestReceivePostingIsFull(t *testing.T) {
first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted"))
c.pushRxCompletion(first.Size, []queue.RxBuffer{first})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that packet is received.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
@@ -768,7 +715,7 @@ func TestReceivePostingIsFull(t *testing.T) {
second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted"))
c.pushRxCompletion(second.Size, []queue.RxBuffer{second})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that no packet is received yet, as the worker is blocked trying
// to repost.
@@ -781,7 +728,7 @@ func TestReceivePostingIsFull(t *testing.T) {
// Flush tx queue, which will allow the first buffer to be reposted,
// and the second completion to be pulled.
c.rxq.tx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that second packet completes.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet")
@@ -803,7 +750,7 @@ func TestCloseWhileWaitingToPost(t *testing.T) {
bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted"))
c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for packet to be indicated.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
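
The raw unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) calls above become c.rxCfg.EventFD.Notify(), using the new pkg/eventfd package. Assuming Notify is a thin wrapper over the same write (the helper below is an illustrative equivalent, not the package's actual code):

// notify signals an eventfd by writing an 8-byte little-endian counter
// increment of 1, exactly what the removed byte-slice literals encoded.
func notify(efd int) error {
	var buf [8]byte
	binary.LittleEndian.PutUint64(buf[:], 1)
	_, err := unix.Write(efd, buf[:])
	return err
}
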
diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go
index e3210051f..35e5bff12 100644
--- a/pkg/tcpip/link/sharedmem/tx.go
+++ b/pkg/tcpip/link/sharedmem/tx.go
@@ -18,6 +18,7 @@ import (
"math"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
)
@@ -28,10 +29,12 @@ const (
// tx holds all state associated with a tx queue.
type tx struct {
- data []byte
- q queue.Tx
- ids idManager
- bufs bufferManager
+ data []byte
+ q queue.Tx
+ ids idManager
+ bufs bufferManager
+ eventFD eventfd.Eventfd
+ sharedDataFD int
}
// init initializes all state needed by the tx queue based on the information
@@ -64,7 +67,8 @@ func (t *tx) init(mtu uint32, c *QueueConfig) error {
t.ids.init()
t.bufs.init(0, len(data), int(mtu))
t.data = data
-
+ t.eventFD = c.EventFD
+ t.sharedDataFD = c.SharedDataFD
return nil
}
@@ -142,6 +146,12 @@ func (t *tx) transmit(bufs ...buffer.View) bool {
return true
}
+// notify writes to the tx.eventFD to indicate to the peer that there is data to
+// be read.
+func (t *tx) notify() {
+ t.eventFD.Notify()
+}
+
// getBuffer returns a memory region mapped to the full contents of the given
// file descriptor.
func getBuffer(fd int) ([]byte, error) {
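
getBuffer's doc comment above promises a memory region mapped to the full contents of a descriptor. A sketch of one way to implement that with golang.org/x/sys/unix, assuming a plain shared mapping (the real implementation may add size or sealing checks):

func getBufferSketch(fd int) ([]byte, error) {
	var st unix.Stat_t
	if err := unix.Fstat(fd, &st); err != nil {
		return nil, err
	}
	// Map the whole file read/write and shared, so that both endpoints
	// observe each other's writes.
	return unix.Mmap(fd, 0, int(st.Size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED)
}
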
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 28a172e71..2afa95af0 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -140,11 +140,6 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
-// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
-func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
-}
-
func (e *endpoint) dumpPacket(dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index 4758a99ad..c3e4c3455 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -31,7 +31,6 @@ go_library(
"//pkg/refs",
"//pkg/refsvfs2",
"//pkg/sync",
- "//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index d23210503..fa2131c28 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -174,7 +173,7 @@ func (d *Device) Write(data []byte) (int64, error) {
return 0, linuxerr.EBADFD
}
if !endpoint.IsAttached() {
- return 0, syserror.EIO
+ return 0, linuxerr.EIO
}
dataLen := int64(len(data))
@@ -249,7 +248,7 @@ func (d *Device) Read() ([]byte, error) {
for {
info, ok := endpoint.Read()
if !ok {
- return nil, syserror.ErrWouldBlock
+ return nil, linuxerr.ErrWouldBlock
}
v, ok := d.encodePkt(&info)
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index a95602aa5..116e4defb 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -59,15 +59,6 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatchGate.Leave()
}
-// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
-func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- if !e.dispatchGate.Enter() {
- return
- }
- e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
- e.dispatchGate.Leave()
-}
-
// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and
// registers with the lower endpoint as its dispatcher so that "e" is called
// for inbound packets.
@@ -155,3 +146,6 @@ func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.lower.AddHeader(local, remote, protocol, pkt)
}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index a71400ee9..b0e4237bd 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -80,6 +80,11 @@ func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBuffe
return pkts.Len(), nil
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*countedEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("unimplemented")
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 7b1ff44f4..c0179104a 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -23,8 +23,10 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 6515c31e5..e08243547 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -272,7 +272,6 @@ type protocol struct {
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
-func (p *protocol) DefaultPrefixLen() int { return 0 }
func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) {
return "", ""
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 5fcbfeaa2..061cc35ae 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -153,8 +153,12 @@ func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext
t.Fatalf("CreateNIC failed: %s", err)
}
- if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %s", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: stackAddr.WithPrefix(),
+ }
+ if err := tc.s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
tc.s.SetRouteTable([]tcpip.Route{{
@@ -569,8 +573,12 @@ func TestLinkAddressRequest(t *testing.T) {
}
if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: test.nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
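
The AddAddress to AddProtocolAddress rewrite above recurs in nearly every test touched by this change. A hypothetical helper capturing the repeated shape (the helper itself is not part of the diff; the calls it wraps are):

func addAddr(t *testing.T, s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, addr tcpip.Address) {
	t.Helper()
	protocolAddr := tcpip.ProtocolAddress{
		Protocol:          proto,
		AddressWithPrefix: addr.WithPrefix(),
	}
	if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
		t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
	}
}
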
diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go
index 605e9ef8d..4d4d98caf 100644
--- a/pkg/tcpip/network/internal/testutil/testutil.go
+++ b/pkg/tcpip/network/internal/testutil/testutil.go
@@ -101,6 +101,11 @@ func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return heade
func (*MockLinkEndpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
}
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*MockLinkEndpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
// how many random bytes will be copied in the Transport Header.
// extraHeaderReserveLength indicates how much extra space will be reserved for
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 771b9173a..87f650661 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -15,6 +15,7 @@
package ip_test
import (
+ "bytes"
"fmt"
"strings"
"testing"
@@ -32,8 +33,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const nicID = 1
@@ -230,7 +233,13 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv4.ProtocolNumber, local)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: local.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, err
+ }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
Gateway: ipv4Gateway,
@@ -246,7 +255,13 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv6.ProtocolNumber, local)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: local.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, err
+ }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
Gateway: ipv6Gateway,
@@ -269,13 +284,13 @@ func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *c
}
v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v4Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err)
+ if err := s.AddProtocolAddress(nicID, v4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v4Addr, err)
}
v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v6Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err)
+ if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v6Addr, err)
}
return s, e
@@ -710,8 +725,8 @@ func TestReceive(t *testing.T) {
if !ok {
t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum)
}
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", test.epAddr, err)
} else {
ep.DecRef()
}
@@ -882,8 +897,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -968,8 +983,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -1234,8 +1249,8 @@ func TestIPv6ReceiveControl(t *testing.T) {
t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv6Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -1301,7 +1316,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name string
protoFactory stack.NetworkProtocolFactory
protoNum tcpip.NetworkProtocolNumber
- nicAddr tcpip.Address
+ nicAddr tcpip.AddressWithPrefix
remoteAddr tcpip.Address
pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
@@ -1311,7 +1326,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
@@ -1352,7 +1367,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with IHL too small",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
@@ -1376,7 +1391,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 too small",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
@@ -1394,7 +1409,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 minimum size",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
@@ -1430,7 +1445,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with options",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
@@ -1475,7 +1490,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with options and data across views",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
@@ -1516,7 +1531,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(data)
@@ -1556,7 +1571,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 with extension header",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
@@ -1601,7 +1616,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 minimum size",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
@@ -1636,7 +1651,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 too small",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
@@ -1660,11 +1675,11 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}{
{
name: "unspecified source",
- srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))),
+ srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr.Address))),
},
{
name: "random source",
- srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))),
+ srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr.Address))),
},
}
@@ -1677,15 +1692,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.protoNum,
+ AddressWithPrefix: test.nicAddr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
- r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */)
+ r, err := s.FindRoute(nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, false /* multicastLoop */)
if err != nil {
- t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err)
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr.Address, test.protoNum, err)
}
defer r.Release()
@@ -2032,3 +2051,97 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) {
})
}
}
+
+func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ addr tcpip.AddressWithPrefix
+ payloadOffset int
+ }{
+ {
+ name: "IPv4",
+ proto: header.IPv4ProtocolNumber,
+ addr: localIPv4AddrWithPrefix,
+ payloadOffset: header.IPv4MinimumSize,
+ },
+ {
+ name: "IPv6",
+ proto: header.IPv6ProtocolNumber,
+ addr: localIPv6AddrWithPrefix,
+ payloadOffset: 0,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
+ ipv6.NewProtocol,
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ RawFactory: raw.EndpointFactory{},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.proto,
+ AddressWithPrefix: test.addr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: test.addr.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ ep, err := s.NewRawEndpoint(udp.ProtocolNumber, test.proto, &wq, true /* associated */)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err)
+ }
+ defer ep.Close()
+
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{
+ Addr: test.addr.Address,
+ },
+ }
+ data := []byte{1, 2, 3, 4}
+ var r bytes.Reader
+ r.Reset(data)
+ if n, err := ep.Write(&r, writeOpts); err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ } else if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want)
+ }
+
+ // Wait for the endpoint to become readable.
+ <-ch
+
+ var w bytes.Buffer
+ rr, err := ep.Read(&w, tcpip.ReadOptions{
+ NeedRemoteAddr: true,
+ })
+ if err != nil {
+ t.Fatalf("ep.Read(...): %s", err)
+ }
+ if diff := cmp.Diff(data, w.Bytes()[test.payloadOffset:]); diff != "" {
+ t.Errorf("payload mismatch (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(tcpip.FullAddress{Addr: test.addr.Address, NIC: nicID}, rr.RemoteAddr); diff != "" {
+ t.Errorf("remote addr mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
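
The readability wait in the test above follows the standard gVisor waiter pattern: register a channel-backed entry for ReadableEvents before issuing the write, then block on the channel before reading. Reduced sketch using the same APIs as the test (ordering matters; registering after the write could miss the event):

var wq waiter.Queue
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.ReadableEvents)
defer wq.EventUnregister(&we)
// ... create the endpoint with &wq and issue the write ...
<-ch // The endpoint is now readable; Read should not return ErrWouldBlock.
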
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 2aa38eb98..1c3b0887f 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -167,14 +167,17 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
p := hdr.TransportProtocol()
dstAddr := hdr.DestinationAddress()
// Skip the ip header, then deliver the error.
- pkt.Data().DeleteFront(hlen)
+ if _, ok := pkt.Data().Consume(hlen); !ok {
+ panic(fmt.Sprintf("could not consume the IP header of %d bytes", hlen))
+ }
e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt)
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
received := e.stats.icmp.packetsReceived
// ICMP packets don't have their TransportHeader fields set. See
- // icmp/protocol.go:protocol.Parse for a full explanation.
+ // icmp/protocol.go:protocol.Parse for a full explanation. Not all ICMP types
+ // require consuming the header, so we only call PullUp.
v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize)
if !ok {
received.invalid.Increment()
@@ -240,15 +243,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4Echo:
received.echoRequest.Increment()
- sent := e.stats.icmp.packetsSent
- if !e.protocol.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return
- }
-
// DeliverTransportPacket will take ownership of pkt so don't use it beyond
// this point. Make a deep copy of the data before pkt gets sent as we will
- // be modifying fields.
+ // be modifying fields. Both the ICMP header (with its type modified to
+ // EchoReply) and payload are reused in the reply packet.
//
// TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no
// waiting endpoints. Consider moving responsibility for doing the copy to
@@ -281,6 +279,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
}
defer r.Release()
+ sent := e.stats.icmp.packetsSent
+ if !e.protocol.allowICMPReply(header.ICMPv4EchoReply, header.ICMPv4UnusedCode) {
+ sent.rateLimited.Increment()
+ return
+ }
+
// TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the
// header information, we may have to change this code to handle the
// ICMP header no longer being in the data buffer.
@@ -331,6 +335,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4EchoReply:
received.echoReply.Increment()
+ // ICMP sockets expect the ICMP header to be present, so we don't consume
+ // the ICMP header.
e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
@@ -338,7 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
mtu := h.MTU()
code := h.Code()
- pkt.Data().DeleteFront(header.ICMPv4MinimumSize)
+ if _, ok := pkt.Data().Consume(header.ICMPv4MinimumSize); !ok {
+ panic("could not consume ICMPv4MinimumSize bytes")
+ }
switch code {
case header.ICMPv4HostUnreachable:
e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt)
@@ -562,13 +570,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return &tcpip.ErrNotConnected{}
}
- sent := netEP.stats.icmp.packetsSent
-
- if !p.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return nil
- }
-
transportHeader := pkt.TransportHeader().View()
// Don't respond to icmp error packets.
@@ -606,6 +607,35 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
}
}
+ sent := netEP.stats.icmp.packetsSent
+ icmpType, icmpCode, counter, pointer := func() (header.ICMPv4Type, header.ICMPv4Code, tcpip.MultiCounterStat, byte) {
+ switch reason := reason.(type) {
+ case *icmpReasonPortUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4PortUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonProtoUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4ProtoUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonNetworkUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4NetUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonHostUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonFragmentationNeeded:
+ return header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, sent.dstUnreachable, 0
+ case *icmpReasonTTLExceeded:
+ return header.ICMPv4TimeExceeded, header.ICMPv4TTLExceeded, sent.timeExceeded, 0
+ case *icmpReasonReassemblyTimeout:
+ return header.ICMPv4TimeExceeded, header.ICMPv4ReassemblyTimeout, sent.timeExceeded, 0
+ case *icmpReasonParamProblem:
+ return header.ICMPv4ParamProblem, header.ICMPv4UnusedCode, sent.paramProblem, reason.pointer
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ }()
+
+ if !p.allowICMPReply(icmpType, icmpCode) {
+ sent.rateLimited.Increment()
+ return nil
+ }
+
// Now work out how much of the triggering packet we should return.
// As per RFC 1812 Section 4.3.2.3
//
@@ -658,44 +688,9 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- var counter tcpip.MultiCounterStat
- switch reason := reason.(type) {
- case *icmpReasonPortUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonProtoUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonNetworkUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4NetUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonHostUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4HostUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonFragmentationNeeded:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4FragmentationNeeded)
- counter = sent.dstUnreachable
- case *icmpReasonTTLExceeded:
- icmpHdr.SetType(header.ICMPv4TimeExceeded)
- icmpHdr.SetCode(header.ICMPv4TTLExceeded)
- counter = sent.timeExceeded
- case *icmpReasonReassemblyTimeout:
- icmpHdr.SetType(header.ICMPv4TimeExceeded)
- icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
- counter = sent.timeExceeded
- case *icmpReasonParamProblem:
- icmpHdr.SetType(header.ICMPv4ParamProblem)
- icmpHdr.SetCode(header.ICMPv4UnusedCode)
- icmpHdr.SetPointer(reason.pointer)
- counter = sent.paramProblem
- default:
- panic(fmt.Sprintf("unsupported ICMP type %T", reason))
- }
+ icmpHdr.SetCode(icmpCode)
+ icmpHdr.SetType(icmpType)
+ icmpHdr.SetPointer(pointer)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum()))
if err := route.WritePacket(
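
The hunks above invert the old structure: rather than rate limiting up front and then setting header fields case by case, the type, code, counter, and pointer are selected in a single immediately-invoked function literal so allowICMPReply can gate the reply before it is built. A reduced illustration of the resulting control flow (two cases shown; the diff covers all of them):

icmpType, icmpCode, pointer := func() (header.ICMPv4Type, header.ICMPv4Code, byte) {
	switch reason := reason.(type) {
	case *icmpReasonPortUnreachable:
		return header.ICMPv4DstUnreachable, header.ICMPv4PortUnreachable, 0
	case *icmpReasonParamProblem:
		return header.ICMPv4ParamProblem, header.ICMPv4UnusedCode, reason.pointer
	default:
		panic(fmt.Sprintf("unsupported ICMP type %T", reason))
	}
}()
if !p.allowICMPReply(icmpType, icmpCode) {
	sent.rateLimited.Increment()
	return nil // Rate limited: drop silently, preserving prior behavior.
}
// Only now is the reply allocated and its header populated once.
icmpHdr.SetType(icmpType)
icmpHdr.SetCode(icmpCode)
icmpHdr.SetPointer(pointer)
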
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index 4bd6f462e..c6576fcbc 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -120,9 +120,12 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma
// cycles.
func TestIGMPV1Present(t *testing.T) {
e, s, clock := createStack(t, true)
- addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength},
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
@@ -215,8 +218,15 @@ func TestSendQueuedIGMPReports(t *testing.T) {
// The initial set of IGMP reports that were queued should be sent once an
// address is assigned.
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: stackAddr,
+ PrefixLen: defaultPrefixLength,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if got := reportStat.Value(); got != 1 {
t.Errorf("got reportStat.Value() = %d, want = 1", got)
@@ -350,8 +360,12 @@ func TestIGMPPacketValidation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
e, s, _ := createStack(t, true)
for _, address := range test.stackAddresses {
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: address,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
stats := s.Stats()
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 44c85bdb8..9b71738ae 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -167,6 +167,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
return nil
}
+func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ ep, ok := p.mu.eps[id]
+ return ep, ok
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -240,7 +247,7 @@ func (e *endpoint) Enable() tcpip.Error {
}
// Create an endpoint to receive broadcast packets on this interface.
- ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint})
if err != nil {
return err
}
@@ -419,7 +426,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -459,7 +466,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
@@ -542,7 +549,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName)
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
for pkt := range outputDropped {
pkts.Remove(pkt)
@@ -569,7 +576,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
- postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName)
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
@@ -710,7 +717,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(ep.nic.ID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -737,7 +744,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(r.NICID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -746,7 +753,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
- newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader()))
+ newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength()))
+ newHdr := header.IPv4(newPkt.NetworkHeader().View())
// As per RFC 791 page 30, Time to Live,
//
@@ -755,12 +763,19 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// Even if no local information is available on the time actually
// spent, the field must be decremented by 1.
newHdr.SetTTL(ttl - 1)
+ // We perform a full checksum as we may have updated options above. The IP
+ // header is relatively small so this is not expected to be an expensive
+ // operation.
+ newHdr.SetChecksum(0)
+ newHdr.SetChecksum(^newHdr.CalculateChecksum())
- switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(newHdr).ToVectorisedView(),
- IsForwardedPacket: true,
- })); err.(type) {
+ forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID())
+ if !ok {
+ // The interface was removed after we obtained the route.
+ return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}}
+ }
+
+ switch err := forwardToEp.writePacket(r, newPkt, true /* headerIncluded */); err.(type) {
case nil:
return nil
case *tcpip.ErrMessageTooLong:
@@ -826,7 +841,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -856,6 +871,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
}
func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) {
+ pkt.NICID = e.nic.ID()
+
// Raw socket packets are delivered based solely on the transport protocol
// number. We only require that the packet be valid IPv4, and that they not
// be fragmented.
@@ -863,7 +880,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
}
- pkt.NICID = e.nic.ID()
stats := e.stats
stats.ip.ValidPacketsReceived.Increment()
@@ -924,7 +940,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok {
// iptables is telling us to drop the packet.
stats.ip.IPTablesInputDropped.Increment()
return
@@ -1074,11 +1090,11 @@ func (e *endpoint) Close() {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties)
if err == nil {
e.mu.igmp.sendQueuedReports()
}
@@ -1199,6 +1215,9 @@ type protocol struct {
// eps is keyed by NICID to allow protocol methods to retrieve an endpoint
// when handling a packet, by looking at which NIC handled the packet.
eps map[tcpip.NICID]*endpoint
+
+ // ICMP types to which the stack's global rate limiting applies.
+ icmpRateLimitedTypes map[header.ICMPv4Type]struct{}
}
// defaultTTL is the current default TTL for the protocol. Only the
@@ -1225,11 +1244,6 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv4MinimumSize
}
-// DefaultPrefixLen returns the IPv4 default prefix length.
-func (p *protocol) DefaultPrefixLen() int {
- return header.IPv4AddressSize * 8
-}
-
// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv4(v)
@@ -1319,6 +1333,23 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
+// allowICMPReply reports whether an ICMP reply with the provided type and
+// code may be sent, subject to the per-type rate mask and the stack's
+// global ICMP rate limiter.
+func (p *protocol) allowICMPReply(icmpType header.ICMPv4Type, code header.ICMPv4Code) bool {
+ // Mimic Linux and never rate limit PMTU discovery messages.
+ // https://github.com/torvalds/linux/blob/9e9fb7655ed585da8f468e29221f0ba194a5f613/net/ipv4/icmp.c#L288
+ if icmpType == header.ICMPv4DstUnreachable && code == header.ICMPv4FragmentationNeeded {
+ return true
+ }
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok {
+ return p.stack.AllowICMPMessage()
+ }
+ return true
+}
+
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
@@ -1398,6 +1429,14 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
}
p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[tcpip.NICID]*endpoint)
+ // Set ICMP rate limiting to Linux defaults.
+ // See https://man7.org/linux/man-pages/man7/icmp.7.html.
+ p.mu.icmpRateLimitedTypes = map[header.ICMPv4Type]struct{}{
+ header.ICMPv4DstUnreachable: struct{}{},
+ header.ICMPv4SrcQuench: struct{}{},
+ header.ICMPv4TimeExceeded: struct{}{},
+ header.ICMPv4ParamProblem: struct{}{},
+ }
return p
}
}
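
For reference, Linux expresses the same default as the icmp_ratemask sysctl, a bitmask over ICMP type numbers: the default of 6168 (0x1818) covers destination unreachable (3), source quench (4), time exceeded (11), and parameter problem (12), exactly the set in the map above. A sketch of the equivalent mask-based check (illustrative only; the diff uses the map form):

// defaultRateMask mirrors Linux's default icmp_ratemask of 6168 (0x1818).
const defaultRateMask = 1<<3 | 1<<4 | 1<<11 | 1<<12

// rateLimitedType reports whether replies of the given ICMP type fall
// under the global rate limiter with the default mask.
func rateLimitedType(icmpType uint8) bool {
	return defaultRateMask&(1<<icmpType) != 0
}
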
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 73407be67..ef91245d7 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -101,8 +101,12 @@ func TestExcludeBroadcast(t *testing.T) {
defer ep.Close()
// Add a valid primary endpoint address, now we can connect.
- if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address("\x0a\x00\x00\x02").WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
if err := ep.Connect(randomAddr); err != nil {
t.Errorf("Connect failed: %v", err)
@@ -356,8 +360,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv4ProtoAddr, err)
}
expectedEmittedPacketCount := 1
@@ -369,8 +373,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1184,8 +1188,8 @@ func TestIPv4Sanity(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
}
// Default routes for IPv4 so ICMP can find a route to the remote
@@ -1745,8 +1749,8 @@ func TestInvalidFragments(t *testing.T) {
const (
nicID = 1
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x02")
tos = 0
ident = 1
ttl = 48
@@ -2012,8 +2016,12 @@ func TestInvalidFragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
for _, f := range test.fragments {
@@ -2061,8 +2069,8 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
const (
nicID = 1
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x02")
tos = 0
ident = 1
ttl = 48
@@ -2237,8 +2245,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
@@ -2308,9 +2320,9 @@ func TestReceiveFragments(t *testing.T) {
const (
nicID = 1
- addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
- addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
- addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3
+ addr1 = tcpip.Address("\x0c\xa8\x00\x01") // 192.168.0.1
+ addr2 = tcpip.Address("\x0c\xa8\x00\x02") // 192.168.0.2
+ addr3 = tcpip.Address("\x0c\xa8\x00\x03") // 192.168.0.3
)
// Build and return a UDP header containing payload.
@@ -2703,8 +2715,12 @@ func TestReceiveFragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
wq := waiter.Queue{}
@@ -2985,11 +3001,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
t.Fatalf("CreateNIC(1, _) failed: %s", err)
}
const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
+ src = tcpip.Address("\x10\x00\x00\x01")
+ dst = tcpip.Address("\x10\x00\x00\x02")
)
- if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
mask := tcpip.AddressMask(header.IPv4Broadcast)
@@ -3161,8 +3181,8 @@ func TestPacketQueuing(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -3285,8 +3305,12 @@ func TestCloseLocking(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
@@ -3349,3 +3373,139 @@ func TestCloseLocking(t *testing.T) {
}
}()
}
+
+func TestIcmpRateLimit(t *testing.T) {
+ var (
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 24,
+ },
+ }
+ )
+ const icmpBurst = 5
+ e := channel.New(1, defaultMTU, tcpip.LinkAddress(""))
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: faketime.NewManualClock(),
+ })
+ s.SetICMPBurst(icmpBurst)
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+ tests := []struct {
+ name string
+ createPacket func() buffer.View
+ check func(*testing.T, *channel.Endpoint, int)
+ }{
+ {
+ name: "echo",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpH.SetIdent(1)
+ icmpH.SetSequence(1)
+ icmpH.SetType(header.ICMPv4Echo)
+ icmpH.SetCode(header.ICMPv4UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(^header.Checksum(icmpH, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLength),
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 1,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected echo response, no packet read in endpoint in round %d", round)
+ }
+ if got, want := p.Proto, header.IPv4ProtocolNumber; got != want {
+ t.Errorf("got p.Proto = %d, want = %d", got, want)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply),
+ ))
+ },
+ },
+ {
+ name: "dst unreachable",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv4MinimumSize + header.UDPMinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpH.Encode(&header.UDPFields{
+ SrcPort: 100,
+ DstPort: 101,
+ Length: header.UDPMinimumSize,
+ })
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLength),
+ Protocol: uint8(header.UDPProtocolNumber),
+ TTL: 1,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if round >= icmpBurst {
+ if ok {
+ t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round)
+ }
+ return
+ }
+ if !ok {
+ t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ ))
+ },
+ },
+ }
+ for _, testCase := range tests {
+ t.Run(testCase.name, func(t *testing.T) {
+ for round := 0; round < icmpBurst+1; round++ {
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: testCase.createPacket().ToVectorisedView(),
+ }))
+ testCase.check(t, e, round)
+ }
+ })
+ }
+}
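
The IPv4 TestIcmpRateLimit added above expects exactly icmpBurst destination-unreachable replies before the limiter starts dropping, while echo replies flow freely. A minimal, standalone sketch of that token-bucket behavior, assuming the stack's ICMPRateLimiter is backed by golang.org/x/time/rate (the dependency the ipv6 test BUILD adds below); the names here are illustrative only:

package main

import (
	"fmt"

	"golang.org/x/time/rate"
)

func main() {
	const icmpBurst = 5
	// A full bucket of icmpBurst tokens with a negligible refill rate:
	// the first icmpBurst calls to Allow succeed, the next one fails.
	limiter := rate.NewLimiter(rate.Limit(1e-9), icmpBurst)
	for round := 0; round < icmpBurst+1; round++ {
		fmt.Printf("round %d: allowed=%t\n", round, limiter.Allow())
	}
	// rate.Inf disables limiting: Allow always succeeds.
	fmt.Println("with rate.Inf:", rate.NewLimiter(rate.Inf, 0).Allow())
}
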
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index f99cbf8f3..f814926a3 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -51,6 +51,7 @@ go_test(
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
+ "@org_golang_x_time//rate:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 94caaae6c..ff23d48e7 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -187,7 +187,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// Skip the IP header, then handle the fragmentation header if there
// is one.
- pkt.Data().DeleteFront(header.IPv6MinimumSize)
+ if _, ok := pkt.Data().Consume(header.IPv6MinimumSize); !ok {
+ panic("could not consume IPv6MinimumSize bytes")
+ }
if p == header.IPv6FragmentHeader {
f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize)
if !ok {
@@ -203,7 +205,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// Skip fragmentation header and find out the actual protocol
// number.
- pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize)
+ if _, ok := pkt.Data().Consume(header.IPv6FragmentHeaderSize); !ok {
+ panic("could not consume IPv6FragmentHeaderSize bytes")
+ }
}
e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt)
@@ -325,7 +329,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
received.packetTooBig.Increment()
- hdr, ok := pkt.Data().PullUp(header.ICMPv6PacketTooBigMinimumSize)
+ hdr, ok := pkt.Data().Consume(header.ICMPv6PacketTooBigMinimumSize)
if !ok {
received.invalid.Increment()
return
@@ -334,18 +338,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
if err != nil {
networkMTU = 0
}
- pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize)
e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt)
case header.ICMPv6DstUnreachable:
received.dstUnreachable.Increment()
- hdr, ok := pkt.Data().PullUp(header.ICMPv6DstUnreachableMinimumSize)
+ hdr, ok := pkt.Data().Consume(header.ICMPv6DstUnreachableMinimumSize)
if !ok {
received.invalid.Increment()
return
}
code := header.ICMPv6(hdr).Code()
- pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize)
switch code {
case header.ICMPv6NetworkUnreachable:
e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt)
@@ -692,6 +694,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
}
defer r.Release()
+ if !e.protocol.allowICMPReply(header.ICMPv6EchoReply) {
+ sent.rateLimited.Increment()
+ return
+ }
+
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
Data: pkt.Data().ExtractVV(),
@@ -1174,13 +1181,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return &tcpip.ErrNotConnected{}
}
- sent := netEP.stats.icmp.packetsSent
-
- if !p.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return nil
- }
-
if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
// TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored.
// Unfortunately at this time ICMP Packets do not have a transport
@@ -1198,6 +1198,33 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
}
}
+ sent := netEP.stats.icmp.packetsSent
+ icmpType, icmpCode, counter, typeSpecific := func() (header.ICMPv6Type, header.ICMPv6Code, tcpip.MultiCounterStat, uint32) {
+ switch reason := reason.(type) {
+ case *icmpReasonParameterProblem:
+ return header.ICMPv6ParamProblem, reason.code, sent.paramProblem, reason.pointer
+ case *icmpReasonPortUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonNetUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6NetworkUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonHostUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6AddressUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonPacketTooBig:
+ return header.ICMPv6PacketTooBig, header.ICMPv6UnusedCode, sent.packetTooBig, 0
+ case *icmpReasonHopLimitExceeded:
+ return header.ICMPv6TimeExceeded, header.ICMPv6HopLimitExceeded, sent.timeExceeded, 0
+ case *icmpReasonReassemblyTimeout:
+ return header.ICMPv6TimeExceeded, header.ICMPv6ReassemblyTimeout, sent.timeExceeded, 0
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ }()
+
+ if !p.allowICMPReply(icmpType) {
+ sent.rateLimited.Increment()
+ return nil
+ }
+
network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
// As per RFC 4443 section 2.4
@@ -1232,40 +1259,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
- var counter tcpip.MultiCounterStat
- switch reason := reason.(type) {
- case *icmpReasonParameterProblem:
- icmpHdr.SetType(header.ICMPv6ParamProblem)
- icmpHdr.SetCode(reason.code)
- icmpHdr.SetTypeSpecific(reason.pointer)
- counter = sent.paramProblem
- case *icmpReasonPortUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6PortUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonNetUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6NetworkUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonHostUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6AddressUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonPacketTooBig:
- icmpHdr.SetType(header.ICMPv6PacketTooBig)
- icmpHdr.SetCode(header.ICMPv6UnusedCode)
- counter = sent.packetTooBig
- case *icmpReasonHopLimitExceeded:
- icmpHdr.SetType(header.ICMPv6TimeExceeded)
- icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
- counter = sent.timeExceeded
- case *icmpReasonReassemblyTimeout:
- icmpHdr.SetType(header.ICMPv6TimeExceeded)
- icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
- counter = sent.timeExceeded
- default:
- panic(fmt.Sprintf("unsupported ICMP type %T", reason))
- }
+ icmpHdr.SetType(icmpType)
+ icmpHdr.SetCode(icmpCode)
+ icmpHdr.SetTypeSpecific(typeSpecific)
+
dataRange := newPkt.Data().AsRange()
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
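
The returnError refactor above classifies the ICMP reason into (type, code, counter, type-specific value) up front so that allowICMPReply can be consulted before the reply packet is built, and rate limiting now applies per type rather than to every outgoing error. A self-contained sketch of that gate, under the assumption that the limiter is a plain x/time/rate token bucket; gate and its fields are illustrative stand-ins for the protocol struct, not the stack's real types:

package main

import (
	"fmt"
	"sync"

	"golang.org/x/time/rate"
)

type icmpType uint8

// gate stands in for the protocol struct above: a mask of rate-limited
// ICMP types plus a shared token-bucket limiter.
type gate struct {
	mu           sync.RWMutex
	limitedTypes map[icmpType]struct{}
	limiter      *rate.Limiter
}

// allowReply mirrors allowICMPReply: types in the mask consult the global
// limiter; all other types are always allowed.
func (g *gate) allowReply(t icmpType) bool {
	g.mu.RLock()
	defer g.mu.RUnlock()
	if _, ok := g.limitedTypes[t]; ok {
		return g.limiter.Allow()
	}
	return true
}

func main() {
	const dstUnreachable, echoReply icmpType = 1, 129
	g := &gate{
		limitedTypes: map[icmpType]struct{}{dstUnreachable: {}},
		limiter:      rate.NewLimiter(rate.Limit(1e-9), 2),
	}
	for i := 0; i < 3; i++ {
		// Prints true, true, false: the burst of 2 is exhausted.
		fmt.Println("dst unreachable allowed:", g.allowReply(dstUnreachable))
	}
	// Types outside the mask are never limited.
	fmt.Println("echo reply allowed:", g.allowReply(echoReply))
}
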
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 7c2a3e56b..03d9f425c 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -225,8 +226,8 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
}
addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -407,8 +408,12 @@ func newTestContext(t *testing.T) *testContext {
if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
- if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress lladdr0: %v", err)
+ llProtocolAddr0 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := c.s0.AddProtocolAddress(nicID, llProtocolAddr0, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr0, err)
}
c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
@@ -416,8 +421,12 @@ func newTestContext(t *testing.T) *testContext {
if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil {
- t.Fatalf("AddAddress lladdr1: %v", err)
+ llProtocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr1.WithPrefix(),
+ }
+ if err := c.s1.AddProtocolAddress(nicID, llProtocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr1, err)
}
subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -690,8 +699,12 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -883,8 +896,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -1065,8 +1082,12 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -1240,8 +1261,12 @@ func TestLinkAddressRequest(t *testing.T) {
}
if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: test.nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -1411,12 +1436,14 @@ func TestPacketQueing(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
Clock: clock,
})
+ // Make sure ICMP rate limiting doesn't get in our way.
+ s.SetICMPLimit(rate.Inf)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err)
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1669,8 +1696,12 @@ func TestCallsToNeighborCache(t *testing.T) {
if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
{
@@ -1704,8 +1735,8 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
}
addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
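
The tests in this patch rely on two stack-level knobs: SetICMPBurst bounds how many rate-limited replies may be sent back to back, and SetICMPLimit(rate.Inf) opts a test out of limiting entirely, as TestPacketQueing does above. A minimal configuration sketch using only the calls that appear in these diffs:

package main

import (
	"golang.org/x/time/rate"

	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func main() {
	s := stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
	})
	// Cap bursts of rate-limited ICMP replies at 5, as TestIcmpRateLimit does.
	s.SetICMPBurst(5)
	// Or opt out entirely, as TestPacketQueing does, so rate limiting
	// cannot interfere with unrelated assertions.
	s.SetICMPLimit(rate.Inf)
}
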
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index b1aec5312..600e805f8 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -748,7 +748,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -788,7 +788,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
@@ -871,7 +871,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName)
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
for pkt := range outputDropped {
pkts.Remove(pkt)
@@ -897,7 +897,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
- postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName)
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
@@ -984,7 +984,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(ep.nic.ID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -1015,7 +1015,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(r.NICID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -1024,7 +1024,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
- newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader()))
+ newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength()))
+ newHdr := header.IPv6(newPkt.NetworkHeader().View())
// As per RFC 8200 section 3,
//
@@ -1032,11 +1033,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// each node that forwards the packet.
newHdr.SetHopLimit(hopLimit - 1)
- switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(newHdr).ToVectorisedView(),
- IsForwardedPacket: true,
- })); err.(type) {
+ forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID())
+ if !ok {
+ // The interface was removed after we obtained the route.
+ return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}}
+ }
+
+ switch err := forwardToEp.writePacket(r, newPkt, newPkt.TransportProtocolNumber, true /* headerIncluded */); err.(type) {
case nil:
return nil
case *tcpip.ErrMessageTooLong:
@@ -1097,7 +1100,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -1127,11 +1130,12 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
}
func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) {
+ pkt.NICID = e.nic.ID()
+
// Raw socket packets are delivered based solely on the transport protocol
// number. We only require that the packet be valid IPv6.
e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
- pkt.NICID = e.nic.ID()
stats := e.stats.ip
stats.ValidPacketsReceived.Increment()
@@ -1179,7 +1183,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer,
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesInputDropped.Increment()
return
@@ -1533,19 +1537,22 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe
// If the last header in the payload isn't a known IPv6 extension header,
// handle it as if it is transport layer data.
- // Calculate the number of octets parsed from data. We want to remove all
- // the data except the unparsed portion located at the end, which its size
- // is extHdr.Buf.Size().
+ // Calculate the number of octets parsed from data. We want to consume all
+ // the data except the unparsed portion located at the end, whose size is
+ // extHdr.Buf.Size().
trim := pkt.Data().Size() - extHdr.Buf.Size()
// For unfragmented packets, extHdr still contains the transport header.
- // Get rid of it.
+ // Consume that too.
//
// For reassembled fragments, pkt.TransportHeader is unset, so this is a
// no-op and pkt.Data begins with the transport header.
trim += pkt.TransportHeader().View().Size()
- pkt.Data().DeleteFront(trim)
+ if _, ok := pkt.Data().Consume(trim); !ok {
+ stats.MalformedPacketsReceived.Increment()
+ return fmt.Errorf("could not consume %d bytes", trim)
+ }
stats.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
@@ -1627,12 +1634,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
e.mu.Lock()
defer e.mu.Unlock()
- return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
+ return e.addAndAcquirePermanentAddressLocked(addr, properties)
}
// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but
@@ -1642,8 +1649,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
// solicited-node multicast group and start duplicate address detection.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
- addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties)
if err != nil {
return nil, err
}
@@ -1986,6 +1993,9 @@ type protocol struct {
// eps is keyed by NICID to allow protocol methods to retrieve an endpoint
// when handling a packet, by looking at which NIC handled the packet.
eps map[tcpip.NICID]*endpoint
+
+	// ICMP types to which the stack's global rate limiter applies.
+	icmpRateLimitedTypes map[header.ICMPv6Type]struct{}
}
ids []uint32
@@ -1997,7 +2007,8 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- fragmentation *fragmentation.Fragmentation
+ fragmentation *fragmentation.Fragmentation
+ icmpRateLimiter *stack.ICMPRateLimiter
}
// Number returns the ipv6 protocol number.
@@ -2010,11 +2021,6 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
-// DefaultPrefixLen returns the IPv6 default prefix length.
-func (p *protocol) DefaultPrefixLen() int {
- return header.IPv6AddressSize * 8
-}
-
// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
@@ -2086,6 +2092,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
return nil
}
+func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ ep, ok := p.mu.eps[id]
+ return ep, ok
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -2171,6 +2184,18 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return proto, !fragMore && fragOffset == 0, true
}
+// allowICMPReply reports whether an ICMP reply with the provided type may
+// be sent, subject to the rate-limited type mask and the stack's global
+// ICMP rate limiter.
+func (p *protocol) allowICMPReply(icmpType header.ICMPv6Type) bool {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok {
+ return p.stack.AllowICMPMessage()
+ }
+ return true
+}
+
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
@@ -2267,6 +2292,21 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[tcpip.NICID]*endpoint)
p.SetDefaultTTL(DefaultTTL)
+ // Rate limit the same ICMPv6 types as Linux does by default.
+ //
+ // Default: 0-1,3-127 (rate limit ICMPv6 errors except Packet Too Big)
+ // See https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt.
+ defaultIcmpTypes := make(map[header.ICMPv6Type]struct{})
+ for i := header.ICMPv6Type(0); i < header.ICMPv6EchoRequest; i++ {
+ switch i {
+ case header.ICMPv6PacketTooBig:
+ // Do not rate limit packet too big by default.
+ default:
+ defaultIcmpTypes[i] = struct{}{}
+ }
+ }
+ p.mu.icmpRateLimitedTypes = defaultIcmpTypes
+
return p
}
}
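
A recurring mechanical change in icmp.go and ipv6.go above replaces PullUp-then-DeleteFront pairs with a single Consume call, which checks availability and trims the front of pkt.Data() in one step, so a short buffer surfaces as a single failed return instead of a later panic. A toy illustration of that shape; the data type below is a stand-in, not the stack's real PacketData:

package main

import "fmt"

// data is a toy stand-in for a packet buffer's Data() view.
type data struct{ buf []byte }

// consume mirrors the Consume(size) shape introduced above: it atomically
// checks that size bytes are available and removes them from the front,
// returning the removed bytes on success.
func (d *data) consume(size int) ([]byte, bool) {
	if size > len(d.buf) {
		return nil, false
	}
	head := d.buf[:size]
	d.buf = d.buf[size:]
	return head, true
}

func main() {
	d := &data{buf: []byte{1, 2, 3, 4}}
	if head, ok := d.consume(2); ok {
		fmt.Println(head, d.buf) // [1 2] [3 4]
	}
	if _, ok := d.consume(10); !ok {
		fmt.Println("short buffer: caller can count a malformed packet")
	}
}
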
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index d2a23fd4f..e5286081e 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -41,12 +41,12 @@ import (
)
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
// The least significant 3 bytes are the same as addr2 so both addr2 and
// addr3 will have the same solicited-node address.
- addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
- addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03"
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02")
+ addr4 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03")
// Tests use the extension header identifier values as uint8 instead of
// header.IPv6ExtensionHeaderIdentifier.
@@ -298,16 +298,24 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
// addr2/addr3 yet as we haven't added those addresses.
test.rxf(t, s, e, addr1, snmc, 0)
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr2, err)
}
// Should receive a packet destined to the solicited node address of
// addr2/addr3 now that we have added addr2.
test.rxf(t, s, e, addr1, snmc, 1)
- if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr3.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr3, err)
}
// Should still receive a packet destined to the solicited node address of
@@ -374,8 +382,12 @@ func TestAddIpv6Address(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: test.addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil {
@@ -898,8 +910,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Add a default route so that a return packet knows where to go.
@@ -1992,8 +2008,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
wq := waiter.Queue{}
@@ -2060,8 +2080,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
func TestInvalidIPv6Fragments(t *testing.T) {
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
nicID = 1
hoplimit = 255
@@ -2150,8 +2170,12 @@ func TestInvalidIPv6Fragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
@@ -2216,8 +2240,8 @@ func TestInvalidIPv6Fragments(t *testing.T) {
func TestFragmentReassemblyTimeout(t *testing.T) {
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
nicID = 1
hoplimit = 255
@@ -2402,8 +2426,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
@@ -2645,11 +2673,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
t.Fatalf("CreateNIC(1, _) failed: %s", err)
}
const (
- src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ src = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ dst = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
)
- if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
@@ -3297,8 +3329,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv6ProtoAddr, err)
}
outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
@@ -3306,8 +3338,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -3341,7 +3373,8 @@ func TestForwarding(t *testing.T) {
ipHeaderLength := header.IPv6MinimumSize
icmpHeaderLength := header.ICMPv6MinimumSize
- totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen
+ payloadLength := icmpHeaderLength + test.payloadLength + extHdrLen
+ totalLength := ipHeaderLength + payloadLength
hdr := buffer.NewPrependable(totalLength)
hdr.Prepend(test.payloadLength)
icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength))
@@ -3359,7 +3392,7 @@ func TestForwarding(t *testing.T) {
copy(hdr.Prepend(extHdrLen), extHdrBytes)
ip := header.IPv6(hdr.Prepend(ipHeaderLength))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength),
+ PayloadLength: uint16(payloadLength),
TransportProtocol: transportProtocol,
HopLimit: test.TTL,
SrcAddr: test.sourceAddr,
@@ -3489,3 +3522,149 @@ func TestMultiCounterStatsInitialization(t *testing.T) {
t.Error(err)
}
}
+
+func TestIcmpRateLimit(t *testing.T) {
+ var (
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+ )
+ const icmpBurst = 5
+ e := channel.New(1, defaultMTU, tcpip.LinkAddress(""))
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: faketime.NewManualClock(),
+ })
+ s.SetICMPBurst(icmpBurst)
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+ tests := []struct {
+ name string
+ createPacket func() buffer.View
+ check func(*testing.T, *channel.Endpoint, int)
+ }{
+ {
+ name: "echo",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ icmpH := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ icmpH.SetIdent(1)
+ icmpH.SetSequence(1)
+ icmpH.SetType(header.ICMPv6EchoRequest)
+ icmpH.SetCode(header.ICMPv6UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpH,
+ Src: host2IPv6Addr.AddressWithPrefix.Address,
+ Dst: host1IPv6Addr.AddressWithPrefix.Address,
+ }))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 1,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected echo response, no packet read in endpoint in round %d", round)
+ }
+ if got, want := p.Proto, header.IPv6ProtocolNumber; got != want {
+ t.Errorf("got p.Proto = %d, want = %d", got, want)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply),
+ ))
+ },
+ },
+ {
+ name: "dst unreachable",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv6MinimumSize + header.UDPMinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpH.Encode(&header.UDPFields{
+ SrcPort: 100,
+ DstPort: 101,
+ Length: header.UDPMinimumSize,
+ })
+
+ // Calculate the UDP checksum and set it.
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(nil, sum)
+ udpH.SetChecksum(^udpH.CalculateChecksum(sum))
+
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: 1,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if round >= icmpBurst {
+ if ok {
+ t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round)
+ }
+ return
+ }
+ if !ok {
+ t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ ))
+ },
+ },
+ }
+ for _, testCase := range tests {
+ t.Run(testCase.name, func(t *testing.T) {
+ for round := 0; round < icmpBurst+1; round++ {
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: testCase.createPacket().ToVectorisedView(),
+ }))
+ testCase.check(t, e, round)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index bc9cf6999..3e5c438d3 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -75,8 +75,12 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
// The stack will join an address's solicited node multicast address when
// an address is added. An MLD report message should be sent for the
// solicited-node group.
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -216,8 +220,13 @@ func TestSendQueuedMLDReports(t *testing.T) {
// Note, we will still expect to send a report for the global address's
// solicited node address from the unspecified address as per RFC 3590
// section 4.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ globalProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: globalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, globalProtocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, globalProtocolAddr, properties, err)
}
reportCounter++
if got := reportStat.Value(); got != reportCounter {
@@ -252,8 +261,12 @@ func TestSendQueuedMLDReports(t *testing.T) {
// Adding a link-local address should send a report for its solicited node
// address and globalMulticastAddr.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err)
+ linkLocalProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, linkLocalProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, linkLocalProtocolAddr, err)
}
if dadResolutionTime != 0 {
reportCounter++
@@ -567,8 +580,12 @@ func TestMLDSkipProtocol(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 8837d66d8..938427420 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -1130,7 +1130,11 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config
return nil
}
- addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
+ addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.AddressProperties{
+ PEB: stack.FirstPrimaryEndpoint,
+ ConfigType: configType,
+ Deprecated: deprecated,
+ })
if err != nil {
panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index f0186c64e..8297a7e10 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -144,8 +144,12 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
@@ -406,8 +410,12 @@ func TestNeighborSolicitationResponse(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -602,8 +610,12 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
@@ -831,8 +843,12 @@ func TestNDPValidation(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
@@ -962,8 +978,12 @@ func TestNeighborAdvertisementValidation(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNASize := header.ICMPv6NeighborAdvertMinimumSize
@@ -1283,8 +1303,12 @@ func TestCheckDuplicateAddress(t *testing.T) {
checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}),
))
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
checkDADMsg()
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
index 1b96b1fb8..26640b7ee 100644
--- a/pkg/tcpip/network/multicast_group_test.go
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -151,15 +151,22 @@ func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.Link
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addr := tcpip.AddressWithPrefix{
- Address: stackIPv4Addr,
- PrefixLen: defaultIPv4PrefixLength,
+ addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: stackIPv4Addr,
+ PrefixLen: defaultIPv4PrefixLength,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalIPv6Addr1.WithPrefix(),
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, clock
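
The hunks above all make the same mechanical substitution: the removed AddAddress/AddAddressWithPrefix helpers become a tcpip.ProtocolAddress literal plus a call to AddProtocolAddress with empty AddressProperties. A minimal standalone sketch of the new call shape follows; the stack and loopback NIC setup here is illustrative, and only the ProtocolAddress construction and the AddProtocolAddress call mirror the diff.

package main

import (
	"log"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func main() {
	s := stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
	})
	if err := s.CreateNIC(1, loopback.New()); err != nil {
		log.Fatalf("CreateNIC(1, _): %s", err)
	}
	protocolAddr := tcpip.ProtocolAddress{
		Protocol: ipv4.ProtocolNumber,
		// WithPrefix() attaches the full address-length prefix (/32 here).
		AddressWithPrefix: tcpip.Address("\x7f\x00\x00\x01").WithPrefix(),
	}
	// Empty AddressProperties reproduces the old AddAddress defaults:
	// CanBePrimaryEndpoint, statically configured, not deprecated.
	if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
		log.Fatalf("AddProtocolAddress(1, %+v, {}): %s", protocolAddr, err)
	}
}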
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 009cab643..05b879543 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -146,8 +146,12 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil {
- log.Fatal(err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Add default route.
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index c10b19aa0..a72afadda 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -124,13 +124,13 @@ func main() {
log.Fatalf("Bad IP address: %v", addrName)
}
- var addr tcpip.Address
+ var addrWithPrefix tcpip.AddressWithPrefix
var proto tcpip.NetworkProtocolNumber
if parsedAddr.To4() != nil {
- addr = tcpip.Address(parsedAddr.To4())
+ addrWithPrefix = tcpip.Address(parsedAddr.To4()).WithPrefix()
proto = ipv4.ProtocolNumber
} else if parsedAddr.To16() != nil {
- addr = tcpip.Address(parsedAddr.To16())
+ addrWithPrefix = tcpip.Address(parsedAddr.To16()).WithPrefix()
proto = ipv6.ProtocolNumber
} else {
log.Fatalf("Unknown IP type: %v", addrName)
@@ -176,11 +176,15 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, proto, addr); err != nil {
- log.Fatal(err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
- subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr))))
+ subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addrWithPrefix.Address))), tcpip.AddressMask(strings.Repeat("\x00", len(addrWithPrefix.Address))))
if err != nil {
log.Fatal(err)
}
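
For reference, the zero-filled subnet built above is the catch-all default route: an all-zero address and all-zero mask match every destination. An illustrative equivalent for the IPv4 case, written as it would appear inside the sample's main (where s, log, strings, and tcpip are already in scope); this is a sketch, not part of the change:

subnet, err := tcpip.NewSubnet(
	tcpip.Address(strings.Repeat("\x00", 4)),     // 0.0.0.0
	tcpip.AddressMask(strings.Repeat("\x00", 4)), // /0, matches everything
)
if err != nil {
	log.Fatal(err)
}
s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 1}})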
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 6bce3af04..b0b2d0afd 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -57,6 +57,11 @@ type SocketOptionsHandler interface {
// OnSetReceiveBufferSize is invoked by SO_RCVBUF and SO_RCVBUFFORCE.
OnSetReceiveBufferSize(v, oldSz int64) (newSz int64)
+
+ // WakeupWriters is invoked when the send buffer size for an endpoint is
+ // changed. The handler notifies the writers if the send buffer size is
+ // increased with setsockopt(2) for TCP endpoints.
+ WakeupWriters()
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -98,6 +103,9 @@ func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) {
return v
}
+// WakeupWriters implements SocketOptionsHandler.WakeupWriters.
+func (*DefaultSocketOptionsHandler) WakeupWriters() {}
+
// OnSetReceiveBufferSize implements SocketOptionsHandler.OnSetReceiveBufferSize.
func (*DefaultSocketOptionsHandler) OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) {
return v
@@ -162,10 +170,14 @@ type SocketOptions struct {
// message is passed with incoming packets.
receiveTClassEnabled uint32
- // receivePacketInfoEnabled is used to specify if more inforamtion is
- // provided with incoming packets such as interface index and address.
+ // receivePacketInfoEnabled is used to specify if more information is
+ // provided with incoming IPv4 packets.
receivePacketInfoEnabled uint32
+ // receiveIPv6PacketInfoEnabled is used to specify if more information is
+ // provided with incoming IPv6 packets.
+ receiveIPv6PacketInfoEnabled uint32
+
// hdrIncludeEnabled is used to indicate for a raw endpoint that all packets
// being written have an IP header and the endpoint should not attach an IP
// header.
@@ -352,6 +364,16 @@ func (so *SocketOptions) SetReceivePacketInfo(v bool) {
storeAtomicBool(&so.receivePacketInfoEnabled, v)
}
+// GetIPv6ReceivePacketInfo gets value for IPV6_RECVPKTINFO option.
+func (so *SocketOptions) GetIPv6ReceivePacketInfo() bool {
+ return atomic.LoadUint32(&so.receiveIPv6PacketInfoEnabled) != 0
+}
+
+// SetIPv6ReceivePacketInfo sets value for IPV6_RECVPKTINFO option.
+func (so *SocketOptions) SetIPv6ReceivePacketInfo(v bool) {
+ storeAtomicBool(&so.receiveIPv6PacketInfoEnabled, v)
+}
+
// GetHeaderIncluded gets value for IP_HDRINCL option.
func (so *SocketOptions) GetHeaderIncluded() bool {
return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0
@@ -626,6 +648,9 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
sendBufferSize = so.handler.OnSetSendBufferSize(sendBufferSize)
}
so.sendBufferSize.Store(sendBufferSize)
+ if notify {
+ so.handler.WakeupWriters()
+ }
}
// GetReceiveBufferSize gets value for SO_RCVBUF option.
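
With notify set, SetSendBufferSize now stores the clamped size and then invokes the handler's WakeupWriters hook. Endpoints that embed DefaultSocketOptionsHandler inherit the no-op; TCP overrides it so writers blocked on a full send buffer get another chance once the buffer grows. A minimal sketch of the override pattern, where the condition variable is an illustrative stand-in for TCP's real waker machinery:

package sketch

import (
	"sync"

	"gvisor.dev/gvisor/pkg/tcpip"
)

// endpoint embeds the default handler so every other hook stays a no-op.
type endpoint struct {
	tcpip.DefaultSocketOptionsHandler

	mu   sync.Mutex
	cond *sync.Cond // Illustrative; real TCP uses its own waker machinery.
}

func newEndpoint() *endpoint {
	e := &endpoint{}
	e.cond = sync.NewCond(&e.mu)
	return e
}

// WakeupWriters overrides the embedded no-op. SetSendBufferSize calls it
// (when notify is true) after storing the new size, so blocked writers can
// re-check for space.
func (e *endpoint) WakeupWriters() {
	e.cond.Broadcast()
}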
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index e0847e58a..6c42ab29b 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -85,6 +85,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
+ "//pkg/tcpip/internal/tcp",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/transport/tcpconntrack",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index ae0bb4ace..7e4b5bf74 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -117,10 +117,10 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS
}
// AddAndAcquirePermanentAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) {
+func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
- ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, properties, true /* permanent */)
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
@@ -149,7 +149,7 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr
func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
- ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: peb}, false /* permanent */)
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
@@ -180,7 +180,7 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr
// returned, regardless the kind of address that is being added.
//
// Precondition: a.mu must be write locked.
-func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) {
+func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, properties AddressProperties, permanent bool) (*addressState, tcpip.Error) {
// attemptAddToPrimary is false when the address is already in the primary
// address list.
attemptAddToPrimary := true
@@ -208,7 +208,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
// We now promote the address.
for i, s := range a.mu.primary {
if s == addrState {
- switch peb {
+ switch properties.PEB {
case CanBePrimaryEndpoint:
// The address is already in the primary address list.
attemptAddToPrimary = false
@@ -222,7 +222,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
case NeverPrimaryEndpoint:
a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
default:
- panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
break
}
@@ -262,11 +262,11 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
}
// Acquire the address before returning it.
addrState.mu.refs++
- addrState.mu.deprecated = deprecated
- addrState.mu.configType = configType
+ addrState.mu.deprecated = properties.Deprecated
+ addrState.mu.configType = properties.ConfigType
if attemptAddToPrimary {
- switch peb {
+ switch properties.PEB {
case NeverPrimaryEndpoint:
case CanBePrimaryEndpoint:
a.mu.primary = append(a.mu.primary, addrState)
@@ -285,7 +285,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
a.mu.primary[0] = addrState
}
default:
- panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
}
@@ -489,12 +489,12 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc
// Proceed to add a new temporary endpoint.
addr := localAddr.WithPrefix()
- ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: tempPEB}, false /* permanent */)
if err != nil {
// addAndAcquireAddressLocked only returns an error if the address is
// already assigned but we just checked above if the address exists so we
// expect no error.
- panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err))
+ panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, AddressProperties{PEB: %s}, false): %s", addr, tempPEB, err))
}
// From https://golang.org/doc/faq#nil_error:
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 140f146f6..c55f85743 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -38,9 +38,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
}
{
- ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ ep, err := s.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint})
if err != nil {
- t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err)
+ t.Fatalf("s.AddAndAcquirePermanentAddress(%s, AddressProperties{PEB: NeverPrimaryEndpoint}): %s", addr, err)
}
// We don't need the address endpoint.
ep.DecRef()
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 068dab7ce..16d295271 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -64,13 +64,21 @@ type tuple struct {
// tupleEntry is used to build an intrusive list of tuples.
tupleEntry
- tupleID
-
// conn is the connection tracking entry this tuple belongs to.
conn *conn
// direction is the direction of the tuple.
direction direction
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ tupleID tupleID
+}
+
+func (t *tuple) id() tupleID {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ return t.tupleID
}
// tupleID uniquely identifies a connection in one direction. It currently
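
The +checklocks:mu annotations introduced throughout this change are consumed by gVisor's checklocks analyzer, which statically verifies that annotated fields are only accessed with the named mutex held. A minimal sketch of the two forms used here, a guarded field and a function that requires the caller to hold the lock, assuming the analyzer runs as part of the build:

package checklockssketch

import "sync"

type counter struct {
	mu sync.RWMutex
	// +checklocks:mu
	n int
}

// value acquires the lock itself, satisfying the annotation on n.
func (c *counter) value() int {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.n
}

// incLocked requires callers to already hold c.mu; the analyzer checks
// both this function's body and its call sites.
// +checklocks:c.mu
func (c *counter) incLocked() {
	c.n++
}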
@@ -103,50 +111,43 @@ func (ti tupleID) reply() tupleID {
//
// +stateify savable
type conn struct {
+ ct *ConnTrack
+
// original is the tuple in original direction. It is immutable.
original tuple
- // reply is the tuple in reply direction. It is immutable.
+ // reply is the tuple in reply direction.
reply tuple
- // manip indicates if the packet should be manipulated. It is immutable.
- // TODO(gvisor.dev/issue/5696): Support updating manipulation type.
+ mu sync.RWMutex `state:"nosave"`
+ // Indicates that the connection has been finalized and may handle replies.
+ //
+ // +checklocks:mu
+ finalized bool
+ // manip indicates if the packet should be manipulated.
+ //
+ // +checklocks:mu
manip manipType
-
- // tcbHook indicates if the packet is inbound or outbound to
- // update the state of tcb. It is immutable.
- tcbHook Hook
-
- // mu protects all mutable state.
- mu sync.Mutex `state:"nosave"`
// tcb is TCB control block. It is used to keep track of states
- // of tcp connection and is protected by mu.
+ // of tcp connection.
+ //
+ // +checklocks:mu
tcb tcpconntrack.TCB
// lastUsed is the last time the connection saw a relevant packet, and
- // is updated by each packet on the connection. It is protected by mu.
+ // is updated by each packet on the connection.
//
// TODO(gvisor.dev/issue/5939): do not use the ambient clock.
+ //
+ // +checklocks:mu
lastUsed time.Time `state:".(unixTime)"`
}
-// newConn creates new connection.
-func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
- conn := conn{
- manip: manip,
- tcbHook: hook,
- lastUsed: time.Now(),
- }
- conn.original = tuple{conn: &conn, tupleID: orig}
- conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
- return &conn
-}
-
// timedOut returns whether the connection timed out based on its state.
func (cn *conn) timedOut(now time.Time) bool {
const establishedTimeout = 5 * 24 * time.Hour
const defaultTimeout = 120 * time.Second
- cn.mu.Lock()
- defer cn.mu.Unlock()
+ cn.mu.RLock()
+ defer cn.mu.RUnlock()
if cn.tcb.State() == tcpconntrack.ResultAlive {
// Use the same default as Linux, which doesn't delete
// established connections for 5(!) days.
@@ -159,17 +160,30 @@ func (cn *conn) timedOut(now time.Time) bool {
// update the connection tracking state.
//
-// Precondition: cn.mu must be held.
-func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
+// +checklocks:cn.mu
+func (cn *conn) updateLocked(pkt *PacketBuffer, dir direction) {
+ if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
+ return
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
if cn.tcb.IsEmpty() {
cn.tcb.Init(tcpHeader)
- } else if hook == cn.tcbHook {
+ return
+ }
+
+ switch dir {
+ case dirOriginal:
cn.tcb.UpdateStateOutbound(tcpHeader)
- } else {
+ case dirReply:
cn.tcb.UpdateStateInbound(tcpHeader)
+ default:
+ panic(fmt.Sprintf("unhandled dir = %d", dir))
}
}
@@ -194,44 +208,34 @@ type ConnTrack struct {
// It is immutable.
seed uint32
+ mu sync.RWMutex `state:"nosave"`
// mu protects the buckets slice, but not buckets' contents. Only take
// the write lock if you are modifying the slice or saving for S/R.
- mu sync.RWMutex `state:"nosave"`
-
- // buckets is protected by mu.
+ //
+ // +checklocks:mu
buckets []bucket
}
// +stateify savable
type bucket struct {
- // mu protects tuples.
- mu sync.Mutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
tuples tupleList
}
-// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
-// TCP header.
-//
-// Preconditions: pkt.NetworkHeader() is valid.
-func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
- netHeader := pkt.Network()
- if netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
- }
-
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
+func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber:
+ if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize {
+ return tcpHeader, true
+ }
+ case header.UDPProtocolNumber:
+ if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize {
+ return udpHeader, true
+ }
}
- return tupleID{
- srcAddr: netHeader.SourceAddress(),
- srcPort: tcpHeader.SourcePort(),
- dstAddr: netHeader.DestinationAddress(),
- dstPort: tcpHeader.DestinationPort(),
- transProto: netHeader.TransportProtocol(),
- netProto: pkt.NetworkProtocolNumber,
- }, nil
+ return nil, false
}
func (ct *ConnTrack) init() {
@@ -240,167 +244,185 @@ func (ct *ConnTrack) init() {
ct.buckets = make([]bucket, numBuckets)
}
-// connFor gets the conn for pkt if it exists, or returns nil
-// if it does not. It returns an error when pkt does not contain a valid TCP
-// header.
-// TODO(gvisor.dev/issue/6168): Support UDP.
-func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil, dirOriginal
+func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple {
+ netHeader := pkt.Network()
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
+ return nil
}
- return ct.connForTID(tid)
-}
-func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
- bucket := ct.bucket(tid)
- now := time.Now()
+ tid := tupleID{
+ srcAddr: netHeader.SourceAddress(),
+ srcPort: transportHeader.SourcePort(),
+ dstAddr: netHeader.DestinationAddress(),
+ dstPort: transportHeader.DestinationPort(),
+ transProto: pkt.TransportProtocolNumber,
+ netProto: pkt.NetworkProtocolNumber,
+ }
+
+ bktID := ct.bucket(tid)
ct.mu.RLock()
- defer ct.mu.RUnlock()
- ct.buckets[bucket].mu.Lock()
- defer ct.buckets[bucket].mu.Unlock()
-
- // Iterate over the tuples in a bucket, cleaning up any unused
- // connections we find.
- for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
- // Clean up any timed-out connections we happen to find.
- if ct.reapTupleLocked(other, bucket, now) {
- // The tuple expired.
- continue
- }
- if tid == other.tupleID {
- return other.conn, other.direction
- }
+ bkt := &ct.buckets[bktID]
+ ct.mu.RUnlock()
+
+ now := time.Now()
+ if t := bkt.connForTID(tid, now); t != nil {
+ return t
}
- return nil, dirOriginal
-}
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
+ // Make sure a connection wasn't added between when we last checked the
+ // bucket and acquired the bucket's write lock.
+ if t := bkt.connForTIDRLocked(tid, now); t != nil {
+ return t
}
- if hook != Prerouting && hook != Output {
- return nil
+
+ // This is the first packet we're seeing for the connection. Create an entry
+ // for this new connection.
+ conn := &conn{
+ ct: ct,
+ original: tuple{tupleID: tid, direction: dirOriginal},
+ reply: tuple{tupleID: tid.reply(), direction: dirReply},
+ manip: manipNone,
+ lastUsed: now,
}
+ conn.original.conn = conn
+ conn.reply.conn = conn
+
+ // For now, we only map an entry for the packet's original tuple as NAT may be
+ // performed on this connection. Until the packet goes through all the hooks
+ // and its final address/port is known, we cannot know what the response
+ // packet's addresses/ports will look like.
+ //
+ // This is okay because the destination cannot send its response until it
+ // receives the packet; the packet will only be received once all the hooks
+ // have been performed.
+ //
+ // See (*conn).finalize.
+ bkt.tuples.PushFront(&conn.original)
+ return &conn.original
+}
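
getConnOrMaybeInsertNoop above uses a classic double-checked pattern per bucket: a fast lookup under the read lock, then a second lookup after taking the write lock, because another goroutine may have inserted the tuple in between. A generic, runnable sketch of that pattern:

package dcheck

import "sync"

type entry struct{ refs int }

type table struct {
	mu sync.RWMutex
	m  map[string]*entry
}

func (t *table) getOrInsert(key string) *entry {
	// Fast path: most packets belong to an existing connection, so try a
	// lookup under the read lock first.
	t.mu.RLock()
	e := t.m[key]
	t.mu.RUnlock()
	if e != nil {
		return e
	}

	t.mu.Lock()
	defer t.mu.Unlock()
	// Re-check: another goroutine may have inserted the entry between the
	// RUnlock and the Lock above.
	if e := t.m[key]; e != nil {
		return e
	}
	e = &entry{}
	t.m[key] = e
	return e
}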
- replyTID := tid.reply()
- replyTID.srcAddr = address
- replyTID.srcPort = port
+func (ct *ConnTrack) connForTID(tid tupleID) *tuple {
+ bktID := ct.bucket(tid)
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
- }
- conn = newConn(tid, replyTID, manipDestination, hook)
- ct.insertConn(conn)
- return conn
+ ct.mu.RLock()
+ bkt := &ct.buckets[bktID]
+ ct.mu.RUnlock()
+
+ return bkt.connForTID(tid, time.Now())
}
-func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
- }
- if hook != Input && hook != Postrouting {
- return nil
+func (bkt *bucket) connForTID(tid tupleID, now time.Time) *tuple {
+ bkt.mu.RLock()
+ defer bkt.mu.RUnlock()
+ return bkt.connForTIDRLocked(tid, now)
+}
+
+// +checklocks:bkt.mu
+func (bkt *bucket) connForTIDRLocked(tid tupleID, now time.Time) *tuple {
+ for other := bkt.tuples.Front(); other != nil; other = other.Next() {
+ if tid == other.id() && !other.conn.timedOut(now) {
+ return other
+ }
}
+ return nil
+}
- replyTID := tid.reply()
- replyTID.dstAddr = address
- replyTID.dstPort = port
+func (ct *ConnTrack) finalize(cn *conn) {
+ tid := cn.reply.id()
+ id := ct.bucket(tid)
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
+ ct.mu.RLock()
+ bkt := &ct.buckets[id]
+ ct.mu.RUnlock()
+
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
+
+ if t := bkt.connForTIDRLocked(tid, time.Now()); t != nil {
+ // Another connection for the reply already exists. We can't do much about
+ // this so we leave the connection cn represents in a state where it can
+ // send packets but its responses will be mapped to some other connection.
+ // This may be okay if the connection only expects to send packets without
+ // any responses.
+ return
}
- conn = newConn(tid, replyTID, manipSource, hook)
- ct.insertConn(conn)
- return conn
+
+ bkt.tuples.PushFront(&cn.reply)
}
-// insertConn inserts conn into the appropriate table bucket.
-func (ct *ConnTrack) insertConn(conn *conn) {
- // Lock the buckets in the correct order.
- tupleBucket := ct.bucket(conn.original.tupleID)
- replyBucket := ct.bucket(conn.reply.tupleID)
- ct.mu.RLock()
- defer ct.mu.RUnlock()
- if tupleBucket < replyBucket {
- ct.buckets[tupleBucket].mu.Lock()
- ct.buckets[replyBucket].mu.Lock()
- } else if tupleBucket > replyBucket {
- ct.buckets[replyBucket].mu.Lock()
- ct.buckets[tupleBucket].mu.Lock()
- } else {
- // Both tuples are in the same bucket.
- ct.buckets[tupleBucket].mu.Lock()
- }
-
- // Now that we hold the locks, ensure the tuple hasn't been inserted by
- // another thread.
- // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too?
- alreadyInserted := false
- for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
- if other.tupleID == conn.original.tupleID {
- alreadyInserted = true
- break
+func (cn *conn) finalize() {
+ {
+ cn.mu.RLock()
+ finalized := cn.finalized
+ cn.mu.RUnlock()
+ if finalized {
+ return
}
}
- if !alreadyInserted {
- // Add the tuple to the map.
- ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
- ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ cn.mu.Lock()
+ finalized := cn.finalized
+ cn.finalized = true
+ cn.mu.Unlock()
+ if finalized {
+ return
}
- // Unlocking can happen in any order.
- ct.buckets[tupleBucket].mu.Unlock()
- if tupleBucket != replyBucket {
- ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
- }
+ cn.ct.finalize(cn)
}
-// handlePacket will manipulate the port and address of the packet if the
-// connection exists. Returns whether, after the packet traverses the tables,
-// it should create a new entry in the table.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
- if pkt.NatDone {
- return false
+// performNAT sets up the connection for the specified NAT.
+//
+// Generally, only the first packet of a connection reaches this method;
+// other packets will be manipulated without needing to modify the connection.
+func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) {
+ cn.performNATIfNoop(port, address, dnat)
+ cn.handlePacket(pkt, hook, r)
+}
+
+func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+
+ if cn.finalized {
+ return
}
- switch hook {
- case Prerouting, Input, Output, Postrouting:
- default:
- return false
+ if cn.manip != manipNone {
+ return
}
- // TODO(gvisor.dev/issue/6168): Support UDP.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
- return false
+ cn.reply.mu.Lock()
+ defer cn.reply.mu.Unlock()
+
+ if dnat {
+ cn.reply.tupleID.srcAddr = address
+ cn.reply.tupleID.srcPort = port
+ cn.manip = manipDestination
+ } else {
+ cn.reply.tupleID.dstAddr = address
+ cn.reply.tupleID.dstPort = port
+ cn.manip = manipSource
}
+}
- conn, dir := ct.connFor(pkt)
- // Connection not found for the packet.
- if conn == nil {
- // If this is the last hook in the data path for this packet (Input if
- // incoming, Postrouting if outgoing), indicate that a connection should be
- // inserted by the end of this hook.
- return hook == Input || hook == Postrouting
+func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) {
+ if pkt.NatDone {
+ return
}
- netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
- return false
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
+ return
}
+ netHeader := pkt.Network()
+
// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
// validated if checksum offloading is off. It may require IP defrag if the
// packets are fragmented.
@@ -410,49 +432,58 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
updateSRCFields := false
+ dir := pkt.tuple.direction
+
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+
switch hook {
case Prerouting, Output:
- if conn.manip == manipDestination {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.srcPort
- newAddr = conn.reply.srcAddr
- case dirReply:
- newPort = conn.original.dstPort
- newAddr = conn.original.dstAddr
-
- updateSRCFields = true
- }
+ if cn.manip == manipDestination && dir == dirOriginal {
+ id := cn.reply.id()
+ newPort = id.srcPort
+ newAddr = id.srcAddr
+ pkt.NatDone = true
+ } else if cn.manip == manipSource && dir == dirReply {
+ id := cn.original.id()
+ newPort = id.srcPort
+ newAddr = id.srcAddr
pkt.NatDone = true
}
case Input, Postrouting:
- if conn.manip == manipSource {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.dstPort
- newAddr = conn.reply.dstAddr
-
- updateSRCFields = true
- case dirReply:
- newPort = conn.original.srcPort
- newAddr = conn.original.srcAddr
- }
+ if cn.manip == manipSource && dir == dirOriginal {
+ id := cn.reply.id()
+ newPort = id.dstPort
+ newAddr = id.dstAddr
+ updateSRCFields = true
+ pkt.NatDone = true
+ } else if cn.manip == manipDestination && dir == dirReply {
+ id := cn.original.id()
+ newPort = id.dstPort
+ newAddr = id.dstAddr
+ updateSRCFields = true
pkt.NatDone = true
}
default:
panic(fmt.Sprintf("unrecognized hook = %s", hook))
}
+
if !pkt.NatDone {
- return false
+ return
}
fullChecksum := false
updatePseudoHeader := false
switch hook {
- case Prerouting, Input:
+ case Prerouting:
+ // Packet came from outside the stack so it must have a checksum set
+ // already.
+ fullChecksum = true
+ updatePseudoHeader = true
+ case Input:
case Output, Postrouting:
// Calculate the TCP checksum and set it.
- if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
+ if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
fullChecksum = true
@@ -464,7 +495,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
rewritePacket(
netHeader,
- tcpHeader,
+ transportHeader,
updateSRCFields,
fullChecksum,
updatePseudoHeader,
@@ -472,46 +503,10 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
newAddr,
)
- // Update the state of tcb.
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
// Mark the connection as having been used recently so it isn't reaped.
- conn.lastUsed = time.Now()
+ cn.lastUsed = time.Now()
// Update connection state.
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
-
- return false
-}
-
-// maybeInsertNoop tries to insert a no-op connection entry to keep connections
-// from getting clobbered when replies arrive. It only inserts if there isn't
-// already a connection for pkt.
-//
-// This should be called after traversing iptables rules only, to ensure that
-// pkt.NatDone is set correctly.
-func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
- // If there were a rule applying to this packet, it would be marked
- // with NatDone.
- if pkt.NatDone {
- return
- }
-
- // We only track TCP connections.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
- return
- }
-
- // This is the first packet we're seeing for the TCP connection. Insert
- // the noop entry (an identity mapping) so that the response doesn't
- // get NATed, breaking the connection.
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return
- }
- conn := newConn(tid, tid.reply(), manipNone, hook)
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
- ct.insertConn(conn)
+ cn.updateLocked(pkt, dir)
}
// bucket gets the conntrack bucket for a tupleID.
@@ -563,14 +558,15 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
defer ct.mu.RUnlock()
for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
idx = (i + start) % len(ct.buckets)
- ct.buckets[idx].mu.Lock()
- for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ bkt := &ct.buckets[idx]
+ bkt.mu.Lock()
+ for tuple := bkt.tuples.Front(); tuple != nil; tuple = tuple.Next() {
checked++
- if ct.reapTupleLocked(tuple, idx, now) {
+ if ct.reapTupleLocked(tuple, idx, bkt, now) {
expired++
}
}
- ct.buckets[idx].mu.Unlock()
+ bkt.mu.Unlock()
}
// We already checked buckets[idx].
idx++
@@ -595,44 +591,48 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
// reapTupleLocked tries to remove tuple and its reply from the table. It
// returns whether the tuple's connection has timed out.
//
-// Preconditions:
-// * ct.mu is locked for reading.
-// * bucket is locked.
-func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+// Precondition: ct.mu is read locked and bkt.mu is write locked.
+// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
+// +checklocks:ct.mu
+// +checklocks:bkt.mu
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now time.Time) bool {
if !tuple.conn.timedOut(now) {
return false
}
// To maintain lock order, we can only reap these tuples if the reply
// appears later in the table.
- replyBucket := ct.bucket(tuple.reply())
- if bucket > replyBucket {
+ replyBktID := ct.bucket(tuple.id().reply())
+ if bktID > replyBktID {
return true
}
// Don't re-lock if both tuples are in the same bucket.
- differentBuckets := bucket != replyBucket
- if differentBuckets {
- ct.buckets[replyBucket].mu.Lock()
+ if bktID != replyBktID {
+ replyBkt := &ct.buckets[replyBktID]
+ replyBkt.mu.Lock()
+ removeConnFromBucket(replyBkt, tuple)
+ replyBkt.mu.Unlock()
+ } else {
+ removeConnFromBucket(bkt, tuple)
}
// We have the buckets locked and can remove both tuples.
+ bkt.tuples.Remove(tuple)
+ return true
+}
+
+// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
+// +checklocks:b.mu
+func removeConnFromBucket(b *bucket, tuple *tuple) {
if tuple.direction == dirOriginal {
- ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
+ b.tuples.Remove(&tuple.conn.reply)
} else {
- ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
- }
- ct.buckets[bucket].tuples.Remove(tuple)
-
- // Don't re-unlock if both tuples are in the same bucket.
- if differentBuckets {
- ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
+ b.tuples.Remove(&tuple.conn.original)
}
-
- return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -640,17 +640,22 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
srcPort: epID.LocalPort,
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
- transProto: header.TCPProtocolNumber,
+ transProto: transProto,
netProto: netProto,
}
- conn, _ := ct.connForTID(tid)
- if conn == nil {
+ t := ct.connForTID(tid)
+ if t == nil {
// Not a tracked connection.
return "", 0, &tcpip.ErrNotConnected{}
- } else if conn.manip != manipDestination {
+ }
+
+ t.conn.mu.RLock()
+ defer t.conn.mu.RUnlock()
+ if t.conn.manip != manipDestination {
// Unmanipulated destination.
return "", 0, &tcpip.ErrInvalidOptionValue{}
}
- return conn.original.dstAddr, conn.original.dstPort, nil
+ id := t.conn.original.id()
+ return id.dstAddr, id.dstPort, nil
}
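
The dnat flag threaded through performNAT/performNATIfNoop decides which half of the reply tuple is rewritten: after DNAT the peer answers from the translated destination, so the reply's source changes; after SNAT the peer answers to the translated source, so the reply's destination changes. A standalone sketch of that bookkeeping (the types mirror the diff but are illustrative):

package natsketch

// tupleID mirrors the fields the diff manipulates; illustrative only.
type tupleID struct {
	srcAddr, dstAddr string
	srcPort, dstPort uint16
}

// reply flips the tuple, as tupleID.reply does in conntrack.go.
func (t tupleID) reply() tupleID {
	return tupleID{
		srcAddr: t.dstAddr, srcPort: t.dstPort,
		dstAddr: t.srcAddr, dstPort: t.srcPort,
	}
}

// applyNAT rewrites the reply tuple the way performNATIfNoop does: DNAT
// changes where replies come from; SNAT changes where replies go to.
func applyNAT(reply *tupleID, addr string, port uint16, dnat bool) {
	if dnat {
		reply.srcAddr, reply.srcPort = addr, port
	} else {
		reply.dstAddr, reply.dstPort = addr, port
	}
}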
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 72f66441f..c2f1f4798 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -181,10 +181,6 @@ func (*fwdTestNetworkProtocol) MinimumPacketSize() int {
return fwdTestNetHeaderLen
}
-func (*fwdTestNetworkProtocol) DefaultPrefixLen() int {
- return fwdTestNetDefaultPrefixLen
-}
-
func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
@@ -342,6 +338,10 @@ func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, p
return n, nil
}
+func (*fwdTestLinkEndpoint) WriteRawPacket(*PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
@@ -380,8 +380,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M
if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC #1 failed:", err)
}
- if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress #1 failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fwdTestNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fwdTestNetDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
// NIC 2 has the link address "b", and added the network address 2.
@@ -393,8 +400,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M
if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC #2 failed:", err)
}
- if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress #2 failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fwdTestNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fwdTestNetDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
nic, ok := s.nics[2]
diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go
index 3a20839da..99e5d2df7 100644
--- a/pkg/tcpip/stack/icmp_rate_limit.go
+++ b/pkg/tcpip/stack/icmp_rate_limit.go
@@ -16,6 +16,7 @@ package stack
import (
"golang.org/x/time/rate"
+ "gvisor.dev/gvisor/pkg/tcpip"
)
const (
@@ -31,11 +32,41 @@ const (
// ICMPRateLimiter is a global rate limiter that controls the generation of
// ICMP messages generated by the stack.
type ICMPRateLimiter struct {
- *rate.Limiter
+ limiter *rate.Limiter
+ clock tcpip.Clock
}
// NewICMPRateLimiter returns a global rate limiter for controlling the rate
-// at which ICMP messages are generated by the stack.
-func NewICMPRateLimiter() *ICMPRateLimiter {
- return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)}
+// at which ICMP messages are generated by the stack. The returned limiter
+// does not apply limits to any ICMP types by default.
+func NewICMPRateLimiter(clock tcpip.Clock) *ICMPRateLimiter {
+ return &ICMPRateLimiter{
+ clock: clock,
+ limiter: rate.NewLimiter(icmpLimit, icmpBurst),
+ }
+}
+
+// SetLimit sets a new Limit for the limiter.
+func (l *ICMPRateLimiter) SetLimit(limit rate.Limit) {
+ l.limiter.SetLimitAt(l.clock.Now(), limit)
+}
+
+// Limit returns the maximum overall event rate.
+func (l *ICMPRateLimiter) Limit() rate.Limit {
+ return l.limiter.Limit()
+}
+
+// SetBurst sets a new burst size for the limiter.
+func (l *ICMPRateLimiter) SetBurst(burst int) {
+ l.limiter.SetBurstAt(l.clock.Now(), burst)
+}
+
+// Burst returns the maximum burst size.
+func (l *ICMPRateLimiter) Burst() int {
+ return l.limiter.Burst()
+}
+
+// Allow reports whether one ICMP message may be sent now.
+func (l *ICMPRateLimiter) Allow() bool {
+ return l.limiter.AllowN(l.clock.Now(), 1)
}
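
A usage sketch for the new wrapper; the manual clock is only for determinism (any tcpip.Clock works) and the numeric limits are arbitrary:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/faketime"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func main() {
	// Any tcpip.Clock works; a manual clock keeps the example deterministic.
	l := stack.NewICMPRateLimiter(faketime.NewManualClock())
	l.SetLimit(10) // Refill at 10 messages per second...
	l.SetBurst(5)  // ...allowing bursts of at most 5.
	if l.Allow() {
		fmt.Println("ICMP message may be sent")
	}
}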
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index f152c0d83..5808be685 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -264,26 +264,134 @@ const (
chainReturn
)
-// Check runs pkt through the rules for hook. It returns true when the packet
-// should continue traversing the network stack and false when it should be
-// dropped.
+// CheckPrerouting performs the prerouting hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
//
-// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
- if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool {
+ const hook = Prerouting
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
return true
}
+
+ if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil {
+ pkt.tuple = t
+ t.conn.handlePacket(pkt, hook, nil /* route */)
+ }
+
+ return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */)
+}
+
+// CheckInput performs the input hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
+ const hook = Input
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := pkt.tuple; t != nil {
+ t.conn.handlePacket(pkt, hook, nil /* route */)
+ }
+
+ ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */)
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
+ }
+ pkt.tuple = nil
+ return ret
+}
+
+// CheckForward performs the forward hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool {
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+ return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName)
+}
+
+// CheckOutput performs the output hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool {
+ const hook = Output
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil {
+ pkt.tuple = t
+ t.conn.handlePacket(pkt, hook, r)
+ }
+
+ return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
+}
+
+// CheckPostrouting performs the postrouting hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool {
+ const hook = Postrouting
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ if t := pkt.tuple; t != nil {
+ t.conn.handlePacket(pkt, hook, r)
+ }
+
+ ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName)
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
+ }
+ pkt.tuple = nil
+ return ret
+}
+
+func (it *IPTables) shouldSkip(netProto tcpip.NetworkProtocolNumber) bool {
+ switch netProto {
+ case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
+ default:
+ // IPTables only supports IPv4/IPv6.
+ return true
+ }
+
+ it.mu.RLock()
+ defer it.mu.RUnlock()
// Many users never configure iptables. Spare them the cost of rule
// traversal if rules have never been set.
+ return !it.modified
+}
+
+// check runs pkt through the rules for hook. It returns true when the packet
+// should continue traversing the network stack and false when it should be
+// dropped.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
it.mu.RLock()
defer it.mu.RUnlock()
- if !it.modified {
- return true
- }
-
- // Packets are manipulated only if connection and matching
- // NAT rule exists.
- shouldTrack := it.connections.handlePacket(pkt, hook, r)
// Go through each table containing the hook.
priorities := it.priorities[hook]
@@ -300,7 +408,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
table = it.v4Tables[tableID]
}
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -311,7 +419,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v {
+ switch v, _ := underflow.Target.Action(pkt, hook, r, addressEP); v {
case RuleAccept:
continue
case RuleDrop:
@@ -327,21 +435,6 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
}
}
- // If this connection should be tracked, try to add an entry for it. If
- // traversing the nat table didn't end in adding an entry,
- // maybeInsertNoop will add a no-op entry for the connection. This is
- // needeed when establishing connections so that the SYN/ACK reply to an
- // outgoing SYN is delivered to the correct endpoint rather than being
- // redirected by a prerouting rule.
- //
- // From the iptables documentation: "If there is no rule, a `null'
- // binding is created: this usually does not map the packet, but exists
- // to ensure we don't map another stream over an existing one."
- if shouldTrack {
- it.connections.maybeInsertNoop(pkt, hook)
- }
-
- // Every table returned Accept.
return true
}
@@ -375,19 +468,32 @@ func (it *IPTables) startReaper(interval time.Duration) {
}()
}
-// CheckPackets runs pkts through the rules for hook and returns a map of packets that
-// should not go forward.
+// CheckOutputPackets performs the output hook on the packets.
//
-// Preconditions:
-// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// * pkt.NetworkHeader is not nil.
+// Returns a map of packets that must be dropped.
//
-// NOTE: unlike the Check API the returned map contains packets that should be
-// dropped.
-func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+// Precondition: The packets' network and transport header must be set.
+func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckOutput(pkt, r, outNicName)
+ })
+}
+
+// CheckPostroutingPackets performs the postrouting hook on the packets.
+//
+// Returns a map of packets that must be dropped.
+//
+// Precondition: The packets' network and transport header must be set.
+func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckPostrouting(pkt, r, addressEP, outNicName)
+ })
+}
+
+func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if !pkt.NatDone {
- if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok {
+ if ok := f(pkt); !ok {
if drop == nil {
drop = make(map[*PacketBuffer]struct{})
}
@@ -407,11 +513,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inN
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict {
case RuleAccept:
return chainAccept
@@ -428,7 +534,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, addressEP, inNicName, outNicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -454,7 +560,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
@@ -477,16 +583,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr)
+ return rule.Target.Action(pkt, hook, r, addressEP)
}
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return "", 0, &tcpip.ErrNotConnected{}
}
- return it.connections.originalDst(epID, netProto)
+ return it.connections.originalDst(epID, netProto, transProto)
}
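
The monolithic Check entry point is now one method per hook, with the conntrack calls (getConnOrMaybeInsertNoop, handlePacket, finalize) folded into the hooks that begin and end a packet's traversal. A sketch of the order the IP layer is expected to invoke them in for an inbound, locally delivered packet; the real call sites live in the ipv4/ipv6 endpoints, so this is illustrative only:

package hooksketch

import "gvisor.dev/gvisor/pkg/tcpip/stack"

// deliverInbound is illustrative only. addressEP is the NIC's
// AddressableEndpoint; passing it lets the REDIRECT target pick the
// interface's primary address during prerouting.
func deliverInbound(it *stack.IPTables, pkt *stack.PacketBuffer, addressEP stack.AddressableEndpoint, inNIC string) bool {
	if !it.CheckPrerouting(pkt, addressEP, inNIC) {
		return false // Dropped by the prerouting hook.
	}
	// ...routing decides the packet is for this host...
	// CheckInput also finalizes the packet's conntrack entry, publishing
	// the reply tuple now that any NAT has settled.
	return it.CheckInput(pkt, inNIC)
}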
diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go
index 529e02a07..3d3c39c20 100644
--- a/pkg/tcpip/stack/iptables_state.go
+++ b/pkg/tcpip/stack/iptables_state.go
@@ -26,11 +26,15 @@ type unixTime struct {
// saveLastUsed is invoked by stateify.
func (cn *conn) saveLastUsed() unixTime {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()}
}
// loadLastUsed is invoked by stateify.
func (cn *conn) loadLastUsed(unix unixTime) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
cn.lastUsed = time.Unix(unix.second, unix.nano)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 96cc899bb..7e5a1672a 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -29,7 +29,7 @@ type AcceptTarget struct {
}
// Action implements Target.Action.
-func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -40,7 +40,7 @@ type DropTarget struct {
}
// Action implements Target.Action.
-func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type ErrorTarget struct {
}
// Action implements Target.Action.
-func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -67,7 +67,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -79,7 +79,7 @@ type ReturnTarget struct {
}
// Action implements Target.Action.
-func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleReturn, 0
}
@@ -97,7 +97,7 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -117,6 +117,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
// Change the address to loopback (127.0.0.1 or ::1) in Output and to
// the primary address of the incoming interface in Prerouting.
+ var address tcpip.Address
switch hook {
case Output:
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
@@ -125,48 +126,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
address = header.IPv6Loopback
}
case Prerouting:
- // No-op, as address is already set correctly.
+ // addressEP is expected to be set for the prerouting hook.
+ address = addressEP.MainAddress().Address
default:
panic("redirect target is supported only on output and prerouting hooks")
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- udpHeader := header.UDP(pkt.TransportHeader().View())
-
- if hook == Output {
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- udpHeader,
- false, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- rt.Port,
- address,
- )
- } else {
- udpHeader.SetDestinationPort(rt.Port)
- }
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
- }
-
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
- default:
- return RuleDrop, 0
+ if t := pkt.tuple; t != nil {
+ t.conn.performNAT(pkt, hook, r, rt.Port, address, true /* dnat */)
+ return RuleAccept, 0
}
- return RuleAccept, 0
+ return RuleDrop, 0
}
// SNATTarget modifies the source port/IP in the outgoing packets.
@@ -179,15 +150,7 @@ type SNATTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// Action implements Target.Action.
-func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
- // Sanity check.
- if st.NetworkProtocol != pkt.NetworkProtocolNumber {
- panic(fmt.Sprintf(
- "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
- st.NetworkProtocol, pkt.NetworkProtocolNumber))
- }
-
+func snatAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -198,6 +161,33 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
return RuleDrop, 0
}
+ // TODO(https://gvisor.dev/issue/5773): If the port is in use, pick a
+ // different port.
+ if port == 0 {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
+ case header.UDPProtocolNumber:
+ port = header.UDP(pkt.TransportHeader().View()).SourcePort()
+ case header.TCPProtocolNumber:
+ port = header.TCP(pkt.TransportHeader().View()).SourcePort()
+ }
+ }
+
+ if t := pkt.tuple; t != nil {
+ t.conn.performNAT(pkt, hook, r, port, address, false /* dnat */)
+ }
+
+ return RuleAccept, 0
+}
+
+// Action implements Target.Action.
+func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if st.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ st.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
switch hook {
case Postrouting, Input:
case Prerouting, Output, Forward:
@@ -206,37 +196,43 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
panic(fmt.Sprintf("%s unrecognized", hook))
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- header.UDP(pkt.TransportHeader().View()),
- true, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- st.Port,
- st.Addr,
- )
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
- }
+ return snatAction(pkt, hook, r, st.Port, st.Addr)
+}
- // Set up connection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
+// MasqueradeTarget modifies the source port/IP in outgoing packets.
+type MasqueradeTarget struct {
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if mt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ mt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ switch hook {
+ case Postrouting:
+ case Prerouting, Input, Forward, Output:
+ panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook))
default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ // addressEP is expected to be set for the postrouting hook.
+ ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), false /* allowExpired */)
+ if ep == nil {
+ // No address exists that we can use as a source address.
return RuleDrop, 0
}
- return RuleAccept, 0
+ address := ep.AddressWithPrefix().Address
+ ep.DecRef()
+ return snatAction(pkt, hook, r, 0 /* port */, address)
}
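For reference, a minimal sketch of a Target under the new signature (the type is hypothetical; compare the three targets above):

	// acceptTarget is a hypothetical Target that accepts every packet. It only
	// illustrates the new Action signature, which receives an
	// AddressableEndpoint instead of a ConnTrack and a pre-resolved address.
	type acceptTarget struct{}

	// Action implements Target.Action.
	func (acceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
		return RuleAccept, 0
	}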
func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) {
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 66e5f22ac..b22024667 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -81,17 +81,6 @@ const (
//
// +stateify savable
type IPTables struct {
- // mu protects v4Tables, v6Tables, and modified.
- mu sync.RWMutex
- // v4Tables and v6tables map tableIDs to tables. They hold builtin
- // tables only, not user tables. mu must be locked for accessing.
- v4Tables [NumTables]Table
- v6Tables [NumTables]Table
- // modified is whether tables have been modified at least once. It is
- // used to elide the iptables performance overhead for workloads that
- // don't utilize iptables.
- modified bool
-
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. It is immutable.
@@ -101,6 +90,21 @@ type IPTables struct {
// reaperDone can be signaled to stop the reaper goroutine.
reaperDone chan struct{}
+
+ mu sync.RWMutex
+ // v4Tables and v6tables map tableIDs to tables. They hold builtin
+ // tables only, not user tables.
+ //
+ // +checklocks:mu
+ v4Tables [NumTables]Table
+ // +checklocks:mu
+ v6Tables [NumTables]Table
+ // modified is whether tables have been modified at least once. It is
+ // used to elide the iptables performance overhead for workloads that
+ // don't utilize iptables.
+ //
+ // +checklocks:mu
+ modified bool
}
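The +checklocks annotations are verified statically: the annotated fields may only be accessed while mu is held. A sketch of a conforming accessor (hypothetical name, assuming the TableID index type used by these arrays):

	// getTable returns the requested builtin table, taking the lock as the
	// +checklocks contract on v4Tables/v6Tables requires.
	func (it *IPTables) getTable(id TableID, ipv6 bool) Table {
		it.mu.RLock()
		defer it.mu.RUnlock()
		if ipv6 {
			return it.v6Tables[id]
		}
		return it.v4Tables[id]
	}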
// VisitTargets traverses all the targets of all tables and replaces each with
@@ -352,5 +356,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int)
+ Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 4d5431da1..40b33b6b5 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -333,8 +333,12 @@ func TestDADDisabled(t *testing.T) {
Address: addr1,
PrefixLen: defaultPrefixLen,
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Should get the address immediately since we should not have performed
@@ -379,12 +383,15 @@ func TestDADResolveLoopback(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: addr1,
- PrefixLen: defaultPrefixLen,
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ },
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
@@ -517,8 +524,12 @@ func TestDADResolve(t *testing.T) {
Address: addr1,
PrefixLen: defaultPrefixLen,
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Make sure the address does not resolve before the resolution time has
@@ -740,8 +751,12 @@ func TestDADFail(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet
@@ -778,8 +793,8 @@ func TestDADFail(t *testing.T) {
// Attempting to add the address again should not fail if the address's
// state was cleaned up when DAD failed.
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
})
}
@@ -851,8 +866,12 @@ func TestDADStop(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
@@ -975,17 +994,29 @@ func TestSetNDPConfigurations(t *testing.T) {
// Add addresses for each NIC.
addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix1,
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID1, protocolAddr1, err)
}
addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix2,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID2, protocolAddr2, err)
}
expectDADEvent(nicID2, addr2)
addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix3,
+ }
+ if err := s.AddProtocolAddress(nicID3, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID3, protocolAddr3, err)
}
expectDADEvent(nicID3, addr3)
@@ -2788,8 +2819,12 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
continue
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: test.addrs[j].Address.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{}
@@ -3644,8 +3679,9 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
Protocol: header.IPv6ProtocolNumber,
AddressWithPrefix: addr2,
}
- if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, protoAddr2, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v) = %s", nicID, protoAddr2, properties, err)
}
// addr2 should be more preferred now since it is at the front of the primary
// list.
@@ -3733,8 +3769,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
}
// Add the address as a static address before SLAAC tries to add it.
- if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
+ protocolAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) = %s", protocolAddr, err)
}
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -4073,8 +4110,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// Attempting to add the address manually should not fail if the
// address's state was cleaned up when DAD failed.
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if err := s.RemoveAddress(nicID, addr.Address); err != nil {
t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
@@ -5362,8 +5403,12 @@ func TestRouterSolicitation(t *testing.T) {
}
if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index b854d868c..29d580e76 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -72,9 +72,15 @@ type nic struct {
sync.RWMutex
spoofing bool
promiscuous bool
- // packetEPs is protected by mu, but the contained packetEndpointList are
- // not.
- packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList
+ }
+
+ packetEPs struct {
+ mu sync.RWMutex
+
+ // eps is protected by the mutex, but the values contained in it are not.
+ //
+ // +checklocks:mu
+ eps map[tcpip.NetworkProtocolNumber]*packetEndpointList
}
}
@@ -91,6 +97,8 @@ type packetEndpointList struct {
mu sync.RWMutex
// eps is protected by mu, but the contained PacketEndpoint values are not.
+ //
+ // +checklocks:mu
eps []PacketEndpoint
}
@@ -111,6 +119,12 @@ func (p *packetEndpointList) remove(ep PacketEndpoint) {
}
}
+func (p *packetEndpointList) len() int {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ return len(p.eps)
+}
+
// forEach calls fn with each endpoint in p while holding the read lock on p.
func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
p.mu.RLock()
@@ -143,18 +157,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector),
}
nic.linkResQueue.init(nic)
- nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
+
+ nic.packetEPs.mu.Lock()
+ defer nic.packetEPs.mu.Unlock()
+
+ nic.packetEPs.eps = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0
- // Register supported packet and network endpoint protocols.
- for _, netProto := range header.Ethertypes {
- nic.mu.packetEPs[netProto] = new(packetEndpointList)
- }
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
- nic.mu.packetEPs[netNum] = new(packetEndpointList)
-
netEP := netProto.NewEndpoint(nic, nic)
nic.networkEndpoints[netNum] = netEP
@@ -365,6 +377,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt
pkt.EgressRoute = r
pkt.NetworkProtocolNumber = protocol
+ n.deliverOutboundPacket(r.RemoteLinkAddress, pkt)
+
if err := n.LinkEndpoint.WritePacket(r, protocol, pkt); err != nil {
return err
}
@@ -383,6 +397,7 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
pkt.EgressRoute = r
pkt.NetworkProtocolNumber = protocol
+ n.deliverOutboundPacket(r.RemoteLinkAddress, pkt)
}
writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol)
@@ -501,7 +516,7 @@ func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber,
// addAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error {
ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
return &tcpip.ErrUnknownProtocol{}
@@ -512,7 +527,7 @@ func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
return &tcpip.ErrNotSupported{}
}
- addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, properties)
if err == nil {
// We have no need for the address endpoint.
addressEndpoint.DecRef()
@@ -699,12 +714,9 @@ func (n *nic) isInGroup(addr tcpip.Address) bool {
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- n.mu.RLock()
enabled := n.Enabled()
// If the NIC is not yet enabled, don't receive any packets.
if !enabled {
- n.mu.RUnlock()
-
n.stats.disabledRx.packets.Increment()
n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size()))
return
@@ -715,7 +727,6 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
networkEndpoint, ok := n.networkEndpoints[protocol]
if !ok {
- n.mu.RUnlock()
n.stats.unknownL3ProtocolRcvdPackets.Increment()
return
}
@@ -727,44 +738,87 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
- // Are any packet type sockets listening for this network protocol?
- protoEPs := n.mu.packetEPs[protocol]
- // Other packet type sockets that are listening for all protocols.
- anyEPs := n.mu.packetEPs[header.EthernetProtocolAll]
- n.mu.RUnlock()
-
// Deliver to interested packet endpoints without holding NIC lock.
+ var packetEPPkt *PacketBuffer
deliverPacketEPs := func(ep PacketEndpoint) {
- p := pkt.Clone()
- p.PktType = tcpip.PacketHost
- ep.HandlePacket(n.id, local, protocol, p)
+ if packetEPPkt == nil {
+ // Packet endpoints hold the full packet.
+ //
+ // We perform a deep copy because higher-level endpoints may point to
+ // the middle of a view that is held by a packet endpoint. Save/Restore
+ // does not support overlapping slices and will panic in this case.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports
+ // overlapping slices (e.g. by passing a shallow copy of pkt to the packet
+ // endpoint).
+ packetEPPkt = NewPacketBuffer(PacketBufferOptions{
+ Data: PayloadSince(pkt.LinkHeader()).ToVectorisedView(),
+ })
+ // If a link header was populated in the original packet buffer, then
+ // populate it in the packet buffer we provide to packet endpoints as
+ // packet endpoints inspect link headers.
+ packetEPPkt.LinkHeader().Consume(pkt.LinkHeader().View().Size())
+ packetEPPkt.PktType = tcpip.PacketHost
+ }
+
+ ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone())
}
- if protoEPs != nil {
+
+ n.packetEPs.mu.Lock()
+ // Are any packet type sockets listening for this network protocol?
+ protoEPs, protoEPsOK := n.packetEPs.eps[protocol]
+ // Other packet type sockets that are listening for all protocols.
+ anyEPs, anyEPsOK := n.packetEPs.eps[header.EthernetProtocolAll]
+ n.packetEPs.mu.Unlock()
+
+ if protoEPsOK {
protoEPs.forEach(deliverPacketEPs)
}
- if anyEPs != nil {
+ if anyEPsOK {
anyEPs.forEach(deliverPacketEPs)
}
networkEndpoint.HandlePacket(pkt)
}
-// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
-func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- n.mu.RLock()
+// deliverOutboundPacket delivers outgoing packets to interested endpoints.
+func (n *nic) deliverOutboundPacket(remote tcpip.LinkAddress, pkt *PacketBuffer) {
+ n.packetEPs.mu.RLock()
+ defer n.packetEPs.mu.RUnlock()
// We do not deliver to protocol-specific packet endpoints as on Linux
// only ETH_P_ALL endpoints get outbound packets.
// Add any other packet sockets that may be listening for all protocols.
- eps := n.mu.packetEPs[header.EthernetProtocolAll]
- n.mu.RUnlock()
+ eps, ok := n.packetEPs.eps[header.EthernetProtocolAll]
+ if !ok {
+ return
+ }
+
+ local := n.LinkAddress()
+ var packetEPPkt *PacketBuffer
eps.forEach(func(ep PacketEndpoint) {
- p := pkt.Clone()
- p.PktType = tcpip.PacketOutgoing
- // Add the link layer header as outgoing packets are intercepted
- // before the link layer header is created.
- n.LinkEndpoint.AddHeader(local, remote, protocol, p)
- ep.HandlePacket(n.id, local, protocol, p)
+ if packetEPPkt == nil {
+ // Packet endpoints hold the full packet.
+ //
+ // We perform a deep copy because higher-level endpoints may point to
+ // the middle of a view that is held by a packet endpoint. Save/Restore
+ // does not support overlapping slices and will panic in this case.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports
+ // overlapping slices (e.g. by passing a shallow copy of pkt to the packet
+ // endpoint).
+ packetEPPkt = NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: pkt.AvailableHeaderBytes(),
+ Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
+ })
+ // Add the link layer header as outgoing packets are intercepted before
+ // the link layer header is created and packet endpoints are interested
+ // in the link header.
+ n.LinkEndpoint.AddHeader(local, remote, pkt.NetworkProtocolNumber, packetEPPkt)
+ packetEPPkt.PktType = tcpip.PacketOutgoing
+ }
+
+ ep.HandlePacket(n.id, local, pkt.NetworkProtocolNumber, packetEPPkt.Clone())
})
}
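For context, a sketch of a packet endpoint distinguishing the two delivery paths by PktType (hypothetical type; real packet endpoints live in the transport/packet package):

	// countingEndpoint tallies inbound vs. outbound packets delivered to it.
	type countingEndpoint struct {
		in, out uint64
	}

	// HandlePacket implements PacketEndpoint.HandlePacket.
	func (e *countingEndpoint) HandlePacket(nicID tcpip.NICID, local tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
		switch pkt.PktType {
		case tcpip.PacketHost:
			e.in++ // delivered by DeliverNetworkPacket
		case tcpip.PacketOutgoing:
			e.out++ // delivered by deliverOutboundPacket
		}
	}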
@@ -917,12 +971,13 @@ func (n *nic) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigura
}
func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error {
- n.mu.Lock()
- defer n.mu.Unlock()
+ n.packetEPs.mu.Lock()
+ defer n.packetEPs.mu.Unlock()
- eps, ok := n.mu.packetEPs[netProto]
+ eps, ok := n.packetEPs.eps[netProto]
if !ok {
- return &tcpip.ErrNotSupported{}
+ eps = new(packetEndpointList)
+ n.packetEPs.eps[netProto] = eps
}
eps.add(ep)
@@ -930,14 +985,17 @@ func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
}
func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
- n.mu.Lock()
- defer n.mu.Unlock()
+ n.packetEPs.mu.Lock()
+ defer n.packetEPs.mu.Unlock()
- eps, ok := n.mu.packetEPs[netProto]
+ eps, ok := n.packetEPs.eps[netProto]
if !ok {
return
}
eps.remove(ep)
+ if eps.len() == 0 {
+ delete(n.packetEPs.eps, netProto)
+ }
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 5cb342f78..c8ad93f29 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -127,11 +127,6 @@ func (*testIPv6Protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
-// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen.
-func (*testIPv6Protocol) DefaultPrefixLen() int {
- return header.IPv6AddressSize * 8
-}
-
// ParseAddresses implements NetworkProtocol.ParseAddresses.
func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 9192d8433..888a8bd9d 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -143,6 +143,8 @@ type PacketBuffer struct {
// NetworkPacketInfo holds an incoming packet's network-layer information.
NetworkPacketInfo NetworkPacketInfo
+
+ tuple *tuple
}
// NewPacketBuffer creates a new PacketBuffer with opts.
@@ -282,14 +284,12 @@ func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View {
return v
}
-// Clone makes a shallow copy of pk.
-//
-// Clone should be called in such cases so that no modifications are done to
-// the underlying packet payload.
+// Clone makes a semi-deep copy of pk. The underlying packet payload is
+// shared, so no modifications may be made to it through either copy.
func (pk *PacketBuffer) Clone() *PacketBuffer {
return &PacketBuffer{
PacketBufferEntry: pk.PacketBufferEntry,
- buf: pk.buf,
+ buf: pk.buf.Clone(),
reserved: pk.reserved,
pushed: pk.pushed,
consumed: pk.consumed,
@@ -304,6 +304,7 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
NICID: pk.NICID,
RXTransportChecksumValidated: pk.RXTransportChecksumValidated,
NetworkPacketInfo: pk.NetworkPacketInfo,
+ tuple: pk.tuple,
}
}
@@ -321,25 +322,51 @@ func (pk *PacketBuffer) Network() header.Network {
}
}
-// CloneToInbound makes a shallow copy of the packet buffer to be used as an
-// inbound packet.
+// CloneToInbound makes a semi-deep copy of the packet buffer (similar to
+// Clone) to be used as an inbound packet.
//
// See PacketBuffer.Data for details about how a packet buffer holds an inbound
// packet.
func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
newPk := &PacketBuffer{
- buf: pk.buf,
+ buf: pk.buf.Clone(),
// Treat unfilled header portion as reserved.
reserved: pk.AvailableHeaderBytes(),
+ tuple: pk.tuple,
+ }
+ return newPk
+}
+
+// DeepCopyForForwarding creates a deep copy of the packet buffer for
+// forwarding.
+//
+// The returned packet buffer will have the network and transport headers
+// set if the original packet buffer did.
+func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer {
+ newPk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: reservedHeaderBytes,
+ Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(),
+ IsForwardedPacket: true,
+ })
+
+ {
+ consumeBytes := pk.NetworkHeader().View().Size()
+ if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed {
+ panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes))
+ }
+ newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber
}
- // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
- // maintain this flag in the packet. Currently conntrack needs this flag to
- // tell if a noop connection should be inserted at Input hook. Once conntrack
- // redefines the manipulation field as mutable, we won't need the special noop
- // connection.
- if pk.NatDone {
- newPk.NatDone = true
+
+ {
+ consumeBytes := pk.TransportHeader().View().Size()
+ if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed {
+ panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes))
+ }
+ newPk.TransportProtocolNumber = pk.TransportProtocolNumber
}
+
+ newPk.tuple = pk.tuple
+
return newPk
}
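A sketch of the intended call site in a forwarding path (hypothetical; the copy reserves room for the outgoing NIC's link headers):

	// Network and transport headers set on pkt carry over to the copy.
	newPkt := pkt.DeepCopyForForwarding(int(n.LinkEndpoint.MaxHeaderLength()))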
@@ -391,13 +418,14 @@ func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) {
return d.pk.buf.PullUp(d.pk.dataOffset(), size)
}
-// DeleteFront removes count from the beginning of d. It panics if count >
-// d.Size(). All backing storage references after the front of the d are
-// invalidated.
-func (d PacketData) DeleteFront(count int) {
- if !d.pk.buf.Remove(d.pk.dataOffset(), count) {
- panic("count > d.Size()")
+// Consume is the same as PullUp except that it additionally consumes the
+// returned bytes: subsequent PullUp or Consume calls will not return them.
+func (d PacketData) Consume(size int) (tcpipbuffer.View, bool) {
+ v, ok := d.PullUp(size)
+ if ok {
+ d.pk.consumed += size
}
+ return v, ok
}
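Callers that previously paired PullUp with DeleteFront, copying the view before DeleteFront invalidated it, can now do both in one step. A sketch, with headerLen and processHeader as hypothetical placeholders:

	// Parse a fixed-size header at the front of the data and advance past it.
	hdr, ok := pkt.Data().Consume(headerLen)
	if !ok {
		return // packet too short
	}
	processHeader(hdr) // hdr stays valid; no copy or separate DeleteFront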
// CapLength reduces d to at most length bytes.
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index a8da34992..c376ed1a1 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -435,11 +435,17 @@ func TestPacketBufferData(t *testing.T) {
}
})
- // DeleteFront
+ // Consume.
for _, n := range []int{1, len(tc.data)} {
- t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) {
+ t.Run(fmt.Sprintf("Consume%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
- pkt.Data().DeleteFront(n)
+ v, ok := pkt.Data().Consume(n)
+ if !ok {
+ t.Fatalf("Consume failed")
+ }
+ if want := []byte(tc.data)[:n]; !bytes.Equal(v, want) {
+ t.Fatalf("pkt.Data().Consume(n) = 0x%x, want 0x%x", v, want)
+ }
checkData(t, pkt, []byte(tc.data)[n:])
})
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index dfe2c886f..31b3a554d 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -318,8 +318,7 @@ type PrimaryEndpointBehavior int
const (
// CanBePrimaryEndpoint indicates the endpoint can be used as a primary
- // endpoint for new connections with no local address. This is the
- // default when calling NIC.AddAddress.
+ // endpoint for new connections with no local address.
CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
// FirstPrimaryEndpoint indicates the endpoint should be the first
@@ -332,6 +331,19 @@ const (
NeverPrimaryEndpoint
)
+func (peb PrimaryEndpointBehavior) String() string {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ return "CanBePrimaryEndpoint"
+ case FirstPrimaryEndpoint:
+ return "FirstPrimaryEndpoint"
+ case NeverPrimaryEndpoint:
+ return "NeverPrimaryEndpoint"
+ default:
+ panic(fmt.Sprintf("unknown primary endpoint behavior: %d", peb))
+ }
+}
+
// AddressConfigType is the method used to add an address.
type AddressConfigType int
@@ -351,6 +363,14 @@ const (
AddressConfigSlaacTemp
)
+// AddressProperties contains additional properties that can be configured when
+// adding an address.
+type AddressProperties struct {
+ PEB PrimaryEndpointBehavior
+ ConfigType AddressConfigType
+ Deprecated bool
+}
+
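What used to be positional arguments (peb, configType, deprecated) now travel in one struct, and the zero value reproduces the old defaults (CanBePrimaryEndpoint, AddressConfigStatic, not deprecated). A caller-side sketch:

	properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
	if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
		// handle the tcpip.Error
	}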
// AssignableAddressEndpoint is a reference counted address endpoint that may be
// assigned to a NetworkEndpoint.
type AssignableAddressEndpoint interface {
@@ -457,7 +477,7 @@ type AddressableEndpoint interface {
// Returns *tcpip.ErrDuplicateAddress if the address exists.
//
// Acquires and returns the AddressEndpoint for the added address.
- AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error)
+ AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error)
// RemovePermanentAddress removes the passed address if it is a permanent
// address.
@@ -685,9 +705,6 @@ type NetworkProtocol interface {
// than this targeted at this protocol.
MinimumPacketSize() int
- // DefaultPrefixLen returns the protocol's default prefix length.
- DefaultPrefixLen() int
-
// ParseAddresses returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
@@ -733,16 +750,6 @@ type NetworkDispatcher interface {
//
// DeliverNetworkPacket takes ownership of pkt.
DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
-
- // DeliverOutboundPacket is called by link layer when a packet is being
- // sent out.
- //
- // pkt.LinkHeader may or may not be set before calling
- // DeliverOutboundPacket. Some packets do not have link headers (e.g.
- // packets sent via loopback), and won't have the field set.
- //
- // DeliverOutboundPacket takes ownership of pkt.
- DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -846,6 +853,14 @@ type LinkEndpoint interface {
// offload is enabled. If it will be used for something else, syscall filters
// may need to be updated.
WritePackets(RouteInfo, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
+
+ // WriteRawPacket writes a packet directly to the link.
+ //
+ // If the link-layer has its own header, the payload must already include the
+ // header.
+ //
+ // WriteRawPacket takes ownership of the packet.
+ WriteRawPacket(*PacketBuffer) tcpip.Error
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index c73890c4c..428350f31 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -72,7 +72,8 @@ type Stack struct {
// rawFactory creates raw endpoints. If nil, raw endpoints are
// disabled. It is set during Stack creation and is immutable.
- rawFactory RawFactory
+ rawFactory RawFactory
+ packetEndpointWriteSupported bool
demux *transportDemuxer
@@ -119,8 +120,7 @@ type Stack struct {
// by the stack.
icmpRateLimiter *ICMPRateLimiter
- // seed is a one-time random value initialized at stack startup
- // and is used to seed the TCP port picking on active connections
+ // seed is a one-time random value initialized at stack startup.
//
// TODO(gvisor.dev/issue/940): S/R this field.
seed uint32
@@ -161,6 +161,10 @@ type Stack struct {
// This is required to prevent potential ACK loops.
// Setting this to 0 will disable all rate limiting.
tcpInvalidRateLimit time.Duration
+
+ // tsOffsetSecret is the secret key for generating timestamp offsets
+ // initialized at stack startup.
+ tsOffsetSecret uint32
}
// UniqueID is an abstract generator of unique identifiers.
@@ -215,6 +219,10 @@ type Options struct {
// this is non-nil.
RawFactory RawFactory
+ // AllowPacketEndpointWrite determines if packet endpoints support write
+ // operations.
+ AllowPacketEndpointWrite bool
+
// RandSource is an optional source to use to generate random
// numbers. If omitted it defaults to a Source seeded by the data
// returned by the stack secure RNG.
@@ -356,23 +364,24 @@ func New(opts Options) *Stack {
opts.NUDConfigs.resetInvalidFields()
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- nics: make(map[tcpip.NICID]*nic),
- defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
- cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- tables: opts.IPTables,
- icmpRateLimiter: NewICMPRateLimiter(),
- seed: seed,
- nudConfigs: opts.NUDConfigs,
- uniqueIDGenerator: opts.UniqueID,
- nudDisp: opts.NUDDisp,
- randomGenerator: randomGenerator,
- secureRNG: opts.SecureRNG,
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ nics: make(map[tcpip.NICID]*nic),
+ packetEndpointWriteSupported: opts.AllowPacketEndpointWrite,
+ defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ tables: opts.IPTables,
+ icmpRateLimiter: NewICMPRateLimiter(clock),
+ seed: seed,
+ nudConfigs: opts.NUDConfigs,
+ uniqueIDGenerator: opts.UniqueID,
+ nudDisp: opts.NUDDisp,
+ randomGenerator: randomGenerator,
+ secureRNG: opts.SecureRNG,
sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
@@ -384,6 +393,7 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
tcpInvalidRateLimit: defaultTCPInvalidRateLimit,
+ tsOffsetSecret: randomGenerator.Uint32(),
}
// Add specified network protocols.
@@ -906,46 +916,9 @@ type NICStateFlags struct {
Loopback bool
}
-// AddAddress adds a new network-layer address to the specified NIC.
-func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
- return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
-}
-
-// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
-// the address prefix.
-func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error {
- ap := tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: addr,
- }
- return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint)
-}
-
-// AddProtocolAddress adds a new network-layer protocol address to the
-// specified NIC.
-func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error {
- return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint)
-}
-
-// AddAddressWithOptions is the same as AddAddress, but allows you to specify
-// whether the new endpoint can be primary or not.
-func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error {
- netProto, ok := s.networkProtocols[protocol]
- if !ok {
- return &tcpip.ErrUnknownProtocol{}
- }
- return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, peb)
-}
-
-// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows
-// you to specify whether the new endpoint can be primary or not.
-func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+// AddProtocolAddress adds an address to the specified NIC, possibly with extra
+// properties.
+func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -954,7 +927,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
return &tcpip.ErrUnknownNICID{}
}
- return nic.addAddress(protocolAddress, peb)
+ return nic.addAddress(protocolAddress, properties)
}
// RemoveAddress removes an existing network-layer address from the specified
@@ -1649,9 +1622,27 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress,
ReserveHeaderBytes: int(nic.MaxHeaderLength()),
Data: payload,
})
+ pkt.NetworkProtocolNumber = netProto
return nic.WritePacketToRemote(remote, netProto, pkt)
}
+// WriteRawPacket writes data directly to the specified NIC without adding any
+// headers.
+func (s *Stack) WriteRawPacket(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+ if !ok {
+ return &tcpip.ErrUnknownNICID{}
+ }
+
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ Data: payload,
+ })
+ pkt.NetworkProtocolNumber = proto
+ return nic.WriteRawPacket(pkt)
+}
+
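A usage sketch (hedged: the payload must already carry the link-layer header if the NIC's link endpoint defines one; frame, nicID, and proto are assumed to be in scope):

	payload := buffer.View(frame).ToVectorisedView()
	if err := s.WriteRawPacket(nicID, proto, payload); err != nil {
		// handle the tcpip.Error
	}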
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
@@ -1819,8 +1810,7 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocol
return nic.setNUDConfigs(proto, c)
}
-// Seed returns a 32 bit value that can be used as a seed value for port
-// picking, ISN generation etc.
+// Seed returns a 32 bit value that can be used as a seed value.
//
// NOTE: The seed is generated once during stack initialization only.
func (s *Stack) Seed() uint32 {
@@ -1944,3 +1934,9 @@ func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProto
return false
}
+
+// PacketEndpointWriteSupported returns true iff packet endpoints support write
+// operations.
+func (s *Stack) PacketEndpointWriteSupported() bool {
+ return s.packetEndpointWriteSupported
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 3089c0ef4..c23e91702 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -139,18 +139,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen)
+ hdr, ok := pkt.Data().Consume(fakeNetHeaderLen)
if !ok {
return
}
- // DeleteFront invalidates slices. Make a copy before trimming.
- nb := append([]byte(nil), hdr...)
- pkt.Data().DeleteFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
- tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
- tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
+ tcpip.Address(hdr[srcAddrOffset:srcAddrOffset+1]),
+ tcpip.Address(hdr[dstAddrOffset:dstAddrOffset+1]),
fakeNetNumber,
- tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
+ tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]),
// Nothing checks the error.
nil, /* transport error */
pkt,
@@ -234,10 +231,6 @@ func (*fakeNetworkProtocol) MinimumPacketSize() int {
return fakeNetHeaderLen
}
-func (*fakeNetworkProtocol) DefaultPrefixLen() int {
- return fakeDefaultPrefixLen
-}
-
func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
return f.packetCount[int(intfAddr)%len(f.packetCount)]
}
@@ -349,12 +342,26 @@ func TestNetworkReceive(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr2, err)
}
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
@@ -517,8 +524,15 @@ func TestNetworkSend(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Make sure that the link-layer endpoint received the outbound packet.
@@ -538,12 +552,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x03",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
}
ep2 := channel.New(10, defaultMTU, "")
@@ -551,12 +579,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr4 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x04",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
}
// Set a route table that sends all packets with odd destination
@@ -812,8 +854,15 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
}
ep2 := channel.New(1, defaultMTU, "")
@@ -821,8 +870,15 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr2,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
}
// Set a route table that sends all packets with odd destination
@@ -978,12 +1034,26 @@ func TestRoutes(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x03",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
}
ep2 := channel.New(10, defaultMTU, "")
@@ -991,12 +1061,26 @@ func TestRoutes(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr4 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x04",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
}
// Set a route table that sends all packets with odd destination
@@ -1058,8 +1142,15 @@ func TestAddressRemoval(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -1108,8 +1199,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -1242,8 +1340,15 @@ func TestEndpointExpiration(t *testing.T) {
// 2. Add Address, everything should work.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1270,8 +1375,8 @@ func TestEndpointExpiration(t *testing.T) {
// 4. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1310,8 +1415,8 @@ func TestEndpointExpiration(t *testing.T) {
// 7. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1453,8 +1558,15 @@ func TestExternalSendWithHandleLocal(t *testing.T) {
if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}})
@@ -1510,8 +1622,15 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
@@ -1633,8 +1752,8 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
}
protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}}
- if err := s.AddProtocolAddress(1, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err)
+ if err := s.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", protoAddr, err)
}
r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
@@ -1678,13 +1797,13 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
t.Fatalf("CreateNIC failed: %s", err)
}
nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr}
- if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err)
+ if err := s.AddProtocolAddress(1, nic1ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", nic1ProtoAddr, err)
}
nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr}
- if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
- t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err)
+ if err := s.AddProtocolAddress(2, nic2ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(2, %+v, {}) failed: %s", nic2ProtoAddr, err)
}
// Set the initial route table.
@@ -1726,7 +1845,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
// 2. Case: Having an explicit route for broadcast will select that one.
rt = append(
[]tcpip.Route{
- {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
+ {Destination: header.IPv4Broadcast.WithPrefix().Subnet(), NIC: 1},
},
rt...,
)
@@ -1808,8 +1927,15 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
}
- if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil {
- t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: anyAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
@@ -1886,22 +2012,27 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Add an address and in case of a primary one include a
// prefixLen.
address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen))
+ properties := stack.AddressProperties{PEB: behavior}
if behavior == stack.CanBePrimaryEndpoint {
protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: addrLen * 8,
- },
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: address.WithPrefix(),
}
- if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
}
// Remember the address/prefix.
primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{}
} else {
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
}
}
}
@@ -1996,8 +2127,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
PrefixLen: tc.prefixLen,
},
}
- if err := s.AddProtocolAddress(1, protocolAddress); err != nil {
- t.Fatal("AddProtocolAddress failed:", err)
+ if err := s.AddProtocolAddress(1, protocolAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", protocolAddress, err)
}
// Check that we get the right initial address and prefix length.
@@ -2047,33 +2178,6 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
}
}
-func TestAddAddress(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- var addrGen addressGenerator
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2)
- for _, addrLen := range []int{4, 16} {
- address := addrGen.next(addrLen)
- if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil {
- t.Fatalf("AddAddress(address=%s) failed: %s", address, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
- })
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
func TestAddProtocolAddress(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
@@ -2084,96 +2188,43 @@ func TestAddProtocolAddress(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- var addrGen addressGenerator
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange))
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
- t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err)
- }
- expectedAddresses = append(expectedAddresses, protocolAddress)
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
addrLenRange := []int{4, 16}
behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange))
+ configTypeRange := []stack.AddressConfigType{stack.AddressConfigStatic, stack.AddressConfigSlaac, stack.AddressConfigSlaacTemp}
+ deprecatedRange := []bool{false, true}
+ wantAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)*len(configTypeRange)*len(deprecatedRange))
var addrGen addressGenerator
for _, addrLen := range addrLenRange {
for _, behavior := range behaviorRange {
- address := addrGen.next(addrLen)
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
- })
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddProtocolAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange))
- var addrGen addressGenerator
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- for _, behavior := range behaviorRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err)
+ for _, configType := range configTypeRange {
+ for _, deprecated := range deprecatedRange {
+ address := addrGen.next(addrLen)
+ properties := stack.AddressProperties{
+ PEB: behavior,
+ ConfigType: configType,
+ Deprecated: deprecated,
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v) failed: %s", nicID, protocolAddr, properties, err)
+ }
+ wantAddresses = append(wantAddresses, tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
+ })
}
- expectedAddresses = append(expectedAddresses, protocolAddress)
}
}
}
gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
+ verifyAddresses(t, wantAddresses, gotAddresses)
}
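
The rewritten TestAddProtocolAddress absorbs the coverage of the three deleted tests by iterating over the full AddressProperties matrix (primary-endpoint behavior × config type × deprecation). For reference, a minimal sketch of the consolidated API the tests now exercise; the NIC ID, address bytes, and prefix length are illustrative:

package main

import (
	"log"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func main() {
	s := stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
	})
	if err := s.CreateNIC(1, channel.New(16, 1500, "")); err != nil {
		log.Fatalf("CreateNIC(1, _): %s", err)
	}
	// The address always carries its prefix, and everything the removed
	// *WithOptions variants accepted now travels in AddressProperties
	// (PEB, ConfigType, Deprecated).
	protocolAddr := tcpip.ProtocolAddress{
		Protocol: ipv4.ProtocolNumber,
		AddressWithPrefix: tcpip.AddressWithPrefix{
			Address:   tcpip.Address("\x0a\x00\x00\x01"),
			PrefixLen: 24,
		},
	}
	properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
	if err := s.AddProtocolAddress(1, protocolAddr, properties); err != nil {
		log.Fatalf("AddProtocolAddress(1, %+v, %+v): %s", protocolAddr, properties, err)
	}
}
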
func TestCreateNICWithOptions(t *testing.T) {
@@ -2290,8 +2341,15 @@ func TestNICStats(t *testing.T) {
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed: ", err)
}
- if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: nic.addr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicid, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicid, protocolAddr, err)
}
{
@@ -2735,8 +2793,16 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// be returned by a call to GetMainNICAddress;
// else, it should.
const address1 = tcpip.Address("\x01")
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
+ properties := stack.AddressProperties{PEB: pi}
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr, properties, err)
}
addr, err := s.GetMainNICAddress(nicID, fakeNetNumber)
if err != nil {
@@ -2785,16 +2851,31 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// Add some other address with peb set to
// FirstPrimaryEndpoint.
const address3 = tcpip.Address("\x03")
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err)
-
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address3,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ properties = stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, protocolAddr3, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr3, properties, err)
}
// Add back the address we removed earlier and
// make sure the new peb was respected.
// (The address should just be promoted now).
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ properties = stack.AddressProperties{PEB: ps}
+ if err := s.AddProtocolAddress(nicID, protocolAddr1, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr1, properties, err)
}
var primaryAddrs []tcpip.Address
for _, pa := range s.NICInfo()[nicID].ProtocolAddresses {
@@ -3096,8 +3177,12 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
}
for _, a := range test.nicAddrs {
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil {
- t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: a.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -3203,8 +3288,12 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// The NIC should have joined addr1's solicited node multicast address.
@@ -3359,8 +3448,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
PrefixLen: 128,
},
}
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
// Address should be in the list of all addresses.
@@ -3687,8 +3776,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID1, ep); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
}
s.SetRouteTable(test.routes)
@@ -3750,8 +3839,8 @@ func TestResolveWith(t *testing.T) {
PrefixLen: 24,
},
}
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err)
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
@@ -3792,8 +3881,15 @@ func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -3881,8 +3977,8 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
PrefixLen: 8,
},
}
- if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddress, err)
}
// Check that we get the right initial address and prefix length.
@@ -3990,44 +4086,44 @@ func TestFindRouteWithForwarding(t *testing.T) {
)
type netCfg struct {
- proto tcpip.NetworkProtocolNumber
- factory stack.NetworkProtocolFactory
- nic1Addr tcpip.Address
- nic2Addr tcpip.Address
- remoteAddr tcpip.Address
+ proto tcpip.NetworkProtocolNumber
+ factory stack.NetworkProtocolFactory
+ nic1AddrWithPrefix tcpip.AddressWithPrefix
+ nic2AddrWithPrefix tcpip.AddressWithPrefix
+ remoteAddr tcpip.Address
}
fakeNetCfg := netCfg{
- proto: fakeNetNumber,
- factory: fakeNetFactory,
- nic1Addr: nic1Addr,
- nic2Addr: nic2Addr,
- remoteAddr: remoteAddr,
+ proto: fakeNetNumber,
+ factory: fakeNetFactory,
+ nic1AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic1Addr, PrefixLen: fakeDefaultPrefixLen},
+ nic2AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic2Addr, PrefixLen: fakeDefaultPrefixLen},
+ remoteAddr: remoteAddr,
}
globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16())
globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16())
ipv6LinkLocalNIC1WithGlobalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: llAddr1,
- nic2Addr: globalIPv6Addr2,
- remoteAddr: globalIPv6Addr1,
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: llAddr1.WithPrefix(),
+ nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
+ remoteAddr: globalIPv6Addr1,
}
ipv6GlobalNIC1WithLinkLocalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: globalIPv6Addr1,
- nic2Addr: llAddr1,
- remoteAddr: llAddr2,
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
+ nic2AddrWithPrefix: llAddr1.WithPrefix(),
+ remoteAddr: llAddr2,
}
ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: globalIPv6Addr1,
- nic2Addr: globalIPv6Addr2,
- remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
+ nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
+ remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
}
tests := []struct {
@@ -4036,8 +4132,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg netCfg
forwardingEnabled bool
- addrNIC tcpip.NICID
- localAddr tcpip.Address
+ addrNIC tcpip.NICID
+ localAddrWithPrefix tcpip.AddressWithPrefix
findRouteErr tcpip.Error
dependentOnForwarding bool
@@ -4047,7 +4143,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4056,7 +4152,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4065,7 +4161,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4074,7 +4170,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: true,
},
@@ -4083,7 +4179,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4092,7 +4188,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4101,7 +4197,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4110,7 +4206,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4118,7 +4214,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and localAddr on same NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4126,7 +4222,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and localAddr on same NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: true,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4134,7 +4230,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and localAddr on different NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4142,7 +4238,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and localAddr on different NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: true,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: true,
},
@@ -4166,7 +4262,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and link-local local addr with route on different NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: false,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4174,7 +4270,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and link-local local addr with route on same NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4182,7 +4278,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with route on same NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: false,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4190,7 +4286,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and link-local local addr with route on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4198,7 +4294,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and link-local local addr with route on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4206,7 +4302,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4214,7 +4310,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4222,7 +4318,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local multicast remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4230,7 +4326,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local multicast remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4238,7 +4334,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local multicast remote on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4246,7 +4342,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local multicast remote on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4268,12 +4364,20 @@ func TestFindRouteWithForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err)
}
- if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: test.netCfg.proto,
+ AddressWithPrefix: test.netCfg.nic1AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
}
- if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: test.netCfg.proto,
+ AddressWithPrefix: test.netCfg.nic2AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
}
if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil {
@@ -4282,20 +4386,20 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
- r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ r, err := s.FindRoute(test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
if err == nil {
defer r.Release()
}
if diff := cmp.Diff(test.findRouteErr, err); diff != "" {
- t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff)
+ t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, diff)
}
if test.findRouteErr != nil {
return
}
- if r.LocalAddress() != test.localAddr {
- t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddr)
+ if r.LocalAddress() != test.localAddrWithPrefix.Address {
+ t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddrWithPrefix.Address)
}
if r.RemoteAddress() != test.netCfg.remoteAddr {
t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr)
@@ -4318,8 +4422,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
if !ok {
t.Fatal("packet not sent through ep2")
}
- if pkt.Route.LocalAddress != test.localAddr {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr)
+ if pkt.Route.LocalAddress != test.localAddrWithPrefix.Address {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddrWithPrefix.Address)
}
if pkt.Route.RemoteAddress != test.netCfg.remoteAddr {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr)
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
index 90a8ba6cf..a941091b0 100644
--- a/pkg/tcpip/stack/tcp.go
+++ b/pkg/tcpip/stack/tcp.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/internal/tcp"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
@@ -288,6 +289,12 @@ type TCPSenderState struct {
// RACKState holds the state related to RACK loss detection algorithm.
RACKState TCPRACKState
+
+ // RetransmitTS records the timestamp used to detect spurious recovery.
+ RetransmitTS uint32
+
+ // SpuriousRecovery indicates if the sender entered recovery spuriously.
+ SpuriousRecovery bool
}
// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
@@ -386,6 +393,12 @@ type TCPSndBufState struct {
// SndMTU is the smallest MTU seen in the control packets received.
SndMTU int
+
+ // AutoTuneSndBufDisabled indicates that auto-tuning of the send buffer
+ // is disabled.
+ //
+ // Must be accessed using atomic operations.
+ AutoTuneSndBufDisabled uint32
}
// TCPEndpointStateInner contains the members of TCPEndpointState used directly
@@ -396,7 +409,7 @@ type TCPSndBufState struct {
type TCPEndpointStateInner struct {
// TSOffset is a randomized offset added to the value of the TSVal
// field in the timestamp option.
- TSOffset uint32
+ TSOffset tcp.TSOffset
// SACKPermitted is set to true if the peer sends the TCPSACKPermitted
// option in the SYN/SYN-ACK.
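
The new TCPSenderState fields back timestamp-based detection of spurious loss recovery: the sender records the TSVal sent with the first retransmission, and a later ACK whose echoed timestamp predates that value must have been triggered by the original transmission rather than the retransmission. A sketch of the comparison under 32-bit timestamp wraparound — illustrative only, not the sender's actual code path:

// spuriousRecovery reports whether an ACK's echoed timestamp (TSecr)
// predates the timestamp recorded at the first retransmission, meaning
// the original segment, not the retransmission, was acknowledged.
// Unsigned subtraction followed by a signed comparison stays correct
// across 32-bit wraparound.
func spuriousRecovery(retransmitTS, ackTSEcr uint32) bool {
	return int32(ackTSEcr-retransmitTS) < 0
}
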
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index dda57e225..542d9257c 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -32,11 +32,13 @@ type protocolIDs struct {
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
- // mu protects all fields of the transportEndpoints.
- mu sync.RWMutex
+ mu sync.RWMutex
+ // +checklocks:mu
endpoints map[TransportEndpointID]*endpointsByNIC
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
+ //
+ // +checklocks:mu
rawEndpoints []RawTransportEndpoint
}
@@ -69,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
// descending order of match quality. If a call to yield returns false,
// iterEndpointsLocked stops iteration and returns immediately.
//
-// Preconditions: eps.mu must be locked.
+// +checklocks:eps.mu
func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
@@ -110,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield
// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
// descending order of match quality.
//
-// Preconditions: eps.mu must be locked.
+// +checklocks:eps.mu
func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
var matchedEPs []*endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@@ -122,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []
// findEndpointLocked returns the endpoint that most closely matches the given id.
//
-// Preconditions: eps.mu must be locked.
+// +checklocks:eps.mu
func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
var matchedEP *endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@@ -133,10 +135,12 @@ func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpo
}
type endpointsByNIC struct {
- mu sync.RWMutex
- endpoints map[tcpip.NICID]*multiPortEndpoint
// seed is a random secret for a jenkins hash.
seed uint32
+
+ mu sync.RWMutex
+ // +checklocks:mu
+ endpoints map[tcpip.NICID]*multiPortEndpoint
}
func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
@@ -171,7 +175,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
return true
}
// multiPortEndpoints are guaranteed to have at least one element.
- transEP := selectEndpoint(id, mpep, epsByNIC.seed)
+ transEP := mpep.selectEndpoint(id, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
@@ -200,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, tran
// broadcast like we are doing with handlePacket above?
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt)
+ mpep.selectEndpoint(id, epsByNIC.seed).HandleError(transErr, pkt)
}
// registerEndpoint returns true if it succeeds. It fails and returns
@@ -333,15 +337,18 @@ func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber
//
// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex `state:"nosave"`
demux *transportDemuxer
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
+ flags ports.FlagCounter
+
+ mu sync.RWMutex `state:"nosave"`
// endpoints stores the transport endpoints in the order in which they
// were bound. This is required for UDP SO_REUSEADDR.
+ //
+ // +checklocks:mu
endpoints []TransportEndpoint
- flags ports.FlagCounter
}
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
@@ -362,13 +369,16 @@ func reciprocalScale(val, n uint32) uint32 {
// selectEndpoint calculates a hash of destination and source addresses and
// ports, then uses it to select a socket. In this case, all packets from one
// address will be sent to the same endpoint.
-func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
- if len(mpep.endpoints) == 1 {
- return mpep.endpoints[0]
+func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ if len(ep.endpoints) == 1 {
+ return ep.endpoints[0]
}
- if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent {
- return mpep.endpoints[len(mpep.endpoints)-1]
+ if ep.flags.SharedFlags().ToFlags().Effective().MostRecent {
+ return ep.endpoints[len(ep.endpoints)-1]
}
payload := []byte{
@@ -384,8 +394,8 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(mpep.endpoints)))
- return mpep.endpoints[idx]
+ idx := reciprocalScale(hash, uint32(len(ep.endpoints)))
+ return ep.endpoints[idx]
}
func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
@@ -479,7 +489,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
if !ok {
epsByNIC = &endpointsByNIC{
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
- seed: d.stack.Seed(),
+ seed: d.stack.seed,
}
}
if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil {
@@ -657,7 +667,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
}
}
- ep := selectEndpoint(id, mpep, epsByNIC.seed)
+ ep := mpep.selectEndpoint(id, epsByNIC.seed)
epsByNIC.mu.RUnlock()
return ep
}
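
The demuxer changes apply one locking pattern throughout: each field guarded by a mutex gains a +checklocks annotation, and selectEndpoint becomes a method that takes the read lock itself instead of trusting callers to hold it. A reduced sketch of that pattern, with illustrative names:

package main

import (
	"fmt"
	"sync"
)

type endpointSet struct {
	mu sync.RWMutex
	// +checklocks:mu
	endpoints []string
}

// pick mirrors multiPortEndpoint.selectEndpoint: it acquires the read
// lock before touching the annotated field, so the checklocks analyzer
// can verify every access.
func (s *endpointSet) pick(hash uint32) string {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.endpoints[int(hash)%len(s.endpoints)]
}

func main() {
	s := &endpointSet{endpoints: []string{"a", "b", "c"}}
	fmt.Println(s.pick(7)) // 7 % 3 == 1, so "b"
}
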
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 45b09110d..cd3a8c25a 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -35,7 +35,7 @@ import (
const (
testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ testDstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
testSrcAddrV4 = "\x0a\x00\x00\x01"
testDstAddrV4 = "\x0a\x00\x00\x02"
@@ -64,12 +64,20 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI
}
linkEps[linkEpID] = channelEp
- if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
- t.Fatalf("AddAddress IPv4 failed: %s", err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(testDstAddrV4).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err)
}
- if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
- t.Fatalf("AddAddress IPv6 failed: %s", err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: testDstAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err)
}
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 839178809..655931715 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -357,8 +357,15 @@ func TestTransportReceive(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Create endpoint and connect to remote address.
@@ -428,8 +435,15 @@ func TestTransportControlReceive(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Create endpoint and connect to remote address.
@@ -497,8 +511,15 @@ func TestTransportSend(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 55683b4fb..460a6afaf 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -19,7 +19,7 @@
// The starting point is the creation and configuration of a stack. A stack can
// be created by calling the New() function of the tcpip/stack/stack package;
// configuring a stack involves creating NICs (via calls to Stack.CreateNIC()),
-// adding network addresses (via calls to Stack.AddAddress()), and
+// adding network addresses (via calls to Stack.AddProtocolAddress()), and
// setting a route table (via a call to Stack.SetRouteTable()).
//
// Once a stack is configured, endpoints can be created by calling
@@ -423,9 +423,9 @@ type ControlMessages struct {
// HasTimestamp indicates whether Timestamp is valid/set.
HasTimestamp bool
- // Timestamp is the time (in ns) that the last packet used to create
- // the read data was received.
- Timestamp int64
+ // Timestamp is the time that the last packet used to create the read data
+ // was received.
+ Timestamp time.Time `state:".(int64)"`
// HasInq indicates whether Inq is valid/set.
HasInq bool
@@ -451,6 +451,12 @@ type ControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo IPPacketInfo
+ // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set.
+ HasIPv6PacketInfo bool
+
+ // IPv6PacketInfo holds interface and address data on an incoming packet.
+ IPv6PacketInfo IPv6PacketInfo
+
// HasOriginalDestinationAddress indicates whether OriginalDstAddress is
// set.
HasOriginalDstAddress bool
@@ -465,10 +471,10 @@ type ControlMessages struct {
// PacketOwner is used to get UID and GID of the packet.
type PacketOwner interface {
- // UID returns KUID of the packet.
+ // KUID returns the KUID of the packet.
KUID() uint32
- // GID returns KGID of the packet.
+ // KGID returns the KGID of the packet.
KGID() uint32
}
@@ -1164,6 +1170,14 @@ type IPPacketInfo struct {
DestinationAddr Address
}
+// IPv6PacketInfo is the message structure for IPV6_PKTINFO.
+//
+// +stateify savable
+type IPv6PacketInfo struct {
+ Addr Address
+ NIC NICID
+}
+
// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to
// get/set the default, min and max send buffer sizes.
type SendBufferSizeOption struct {
@@ -1231,11 +1245,11 @@ type Route struct {
// String implements the fmt.Stringer interface.
func (r Route) String() string {
var out strings.Builder
- fmt.Fprintf(&out, "%s", r.Destination)
+ _, _ = fmt.Fprintf(&out, "%s", r.Destination)
if len(r.Gateway) > 0 {
- fmt.Fprintf(&out, " via %s", r.Gateway)
+ _, _ = fmt.Fprintf(&out, " via %s", r.Gateway)
}
- fmt.Fprintf(&out, " nic %d", r.NIC)
+ _, _ = fmt.Fprintf(&out, " nic %d", r.NIC)
return out.String()
}
@@ -1255,6 +1269,8 @@ type TransportProtocolNumber uint32
type NetworkProtocolNumber uint32
// A StatCounter keeps track of a statistic.
+//
+// +stateify savable
type StatCounter struct {
count atomicbitops.AlignedAtomicUint64
}
@@ -1270,7 +1286,7 @@ func (s *StatCounter) Decrement() {
}
// Value returns the current value of the counter.
-func (s *StatCounter) Value(name ...string) uint64 {
+func (s *StatCounter) Value(...string) uint64 {
return s.count.Load()
}
@@ -1849,6 +1865,10 @@ type TCPStats struct {
// SegmentsAckedWithDSACK is the number of segments acknowledged with
// DSACK.
SegmentsAckedWithDSACK *StatCounter
+
+ // SpuriousRecovery is the number of times the connection entered loss
+ // recovery spuriously.
+ SpuriousRecovery *StatCounter
}
// UDPStats collects UDP-specific stats.
@@ -1981,6 +2001,8 @@ type Stats struct {
}
// ReceiveErrors collects packet receive errors within transport endpoint.
+//
+// +stateify savable
type ReceiveErrors struct {
// ReceiveBufferOverflow is the number of received packets dropped
// due to the receive buffer being full.
@@ -1998,8 +2020,10 @@ type ReceiveErrors struct {
ChecksumErrors StatCounter
}
-// SendErrors collects packet send errors within the transport layer for
-// an endpoint.
+// SendErrors collects packet send errors within the transport layer for an
+// endpoint.
+//
+// +stateify savable
type SendErrors struct {
// SendToNetworkFailed is the number of packets failed to be written to
// the network endpoint.
@@ -2010,6 +2034,8 @@ type SendErrors struct {
}
// ReadErrors collects segment read errors from an endpoint read call.
+//
+// +stateify savable
type ReadErrors struct {
// ReadClosed is the number of received packet drops because the endpoint
// was shutdown for read.
@@ -2025,6 +2051,8 @@ type ReadErrors struct {
}
// WriteErrors collects packet write errors from an endpoint write call.
+//
+// +stateify savable
type WriteErrors struct {
// WriteClosed is the number of packet drops because the endpoint
// was shutdown for write.
@@ -2040,6 +2068,8 @@ type WriteErrors struct {
}
// TransportEndpointStats collects statistics about the endpoint.
+//
+// +stateify savable
type TransportEndpointStats struct {
// PacketsReceived is the number of successful packet receives.
PacketsReceived StatCounter
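
Among the tcpip.go additions, IPv6PacketInfo gives ControlMessages an IPV6_PKTINFO counterpart to the existing IPv4 PacketInfo. A hedged sketch of how a receive path might populate the new fields; the NIC ID and link-local address are illustrative:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
)

func main() {
	cm := tcpip.ControlMessages{
		HasIPv6PacketInfo: true,
		IPv6PacketInfo: tcpip.IPv6PacketInfo{
			NIC:  1,
			Addr: tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"),
		},
	}
	fmt.Printf("%+v\n", cm.IPv6PacketInfo)
}
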
diff --git a/pkg/tcpip/tcpip_state.go b/pkg/tcpip/tcpip_state.go
new file mode 100644
index 000000000..1953e24a1
--- /dev/null
+++ b/pkg/tcpip/tcpip_state.go
@@ -0,0 +1,27 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "time"
+)
+
+func (c *ControlMessages) saveTimestamp() int64 {
+ return c.Timestamp.UnixNano()
+}
+
+func (c *ControlMessages) loadTimestamp(nsec int64) {
+ c.Timestamp = time.Unix(0, nsec)
+}
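
This new file pairs with the state:".(int64)" tag on ControlMessages.Timestamp: the checkpoint machinery saves the time.Time as Unix nanoseconds and rebuilds it on restore. A standalone round-trip illustration of what that representation preserves:

package main

import (
	"fmt"
	"time"
)

func main() {
	orig := time.Now()
	nsec := orig.UnixNano()        // what saveTimestamp emits
	restored := time.Unix(0, nsec) // what loadTimestamp rebuilds
	// The wall-clock instant survives at nanosecond precision; only
	// the monotonic clock reading is dropped.
	fmt.Println(orig.Equal(restored)) // true
}
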
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 181ef799e..7c998eaae 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -34,12 +34,16 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
"//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 92fa6257d..6e1d4720d 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -473,11 +473,19 @@ func TestMulticastForwarding(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr,
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -612,8 +620,8 @@ func TestPerInterfaceForwarding(t *testing.T) {
addr: utils.RouterNIC2IPv6Addr,
},
} {
- if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err)
+ if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err)
}
}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index f9ab7d0af..f01e2b128 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -15,19 +15,24 @@
package iptables_test
import (
+ "bytes"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
type inputIfNameMatcher struct {
@@ -49,10 +54,10 @@ const (
nicName = "nic1"
anotherNicName = "nic2"
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- srcAddrV4 = "\x0a\x00\x00\x01"
- dstAddrV4 = "\x0a\x00\x00\x02"
- srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ srcAddrV4 = tcpip.Address("\x0a\x00\x00\x01")
+ dstAddrV4 = tcpip.Address("\x0a\x00\x00\x02")
+ srcAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ dstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
payloadSize = 20
)
@@ -66,8 +71,12 @@ func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) {
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: dstAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, e
}
@@ -82,8 +91,12 @@ func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) {
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: dstAddrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, e
}
@@ -601,11 +614,19 @@ func TestIPTableWritePackets(t *testing.T) {
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: srcAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
+ }
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: srcAddrV4.WithPrefix(),
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -856,11 +877,19 @@ func TestForwardingHook(t *testing.T) {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(),
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -1037,22 +1066,22 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err)
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err)
}
- if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err)
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err)
}
e2 := channel.New(1, header.IPv6MinimumMTU, "")
if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
}
- if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err)
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err)
}
- if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err)
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -1132,3 +1161,312 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
})
}
}
+
+func TestSNAT(t *testing.T) {
+ const listenPort = 8080
+
+ type endpointAndAddresses struct {
+ serverEP tcpip.Endpoint
+ serverAddr tcpip.Address
+ serverReadableCH chan struct{}
+
+ clientEP tcpip.Endpoint
+ clientAddr tcpip.Address
+ clientReadableCH chan struct{}
+
+ nattedClientAddr tcpip.Address
+ }
+
+ newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ t.Cleanup(func() {
+ wq.EventUnregister(&we)
+ })
+
+ ep, err := s.NewEndpoint(transProto, netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
+ }
+ t.Cleanup(ep.Close)
+
+ return ep, ch
+ }
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
+ }{
+ {
+ name: "IPv4 host1 server with host2 client",
+ netProto: ipv4.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+
+ nattedClientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
+ }
+ },
+ },
+ {
+ name: "IPv6 host1 server with host2 client",
+ netProto: ipv6.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ ep1, ep1WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+
+ nattedClientAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
+ }
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ proto tcpip.TransportProtocolNumber
+ expectedConnectErr tcpip.Error
+ setupServer func(t *testing.T, ep tcpip.Endpoint)
+ setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
+ needRemoteAddr bool
+ }{
+ {
+ name: "UDP",
+ proto: udp.ProtocolNumber,
+ expectedConnectErr: nil,
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ if err := ep.Connect(clientAddr); err != nil {
+ t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
+ }
+ return nil, nil
+ },
+ needRemoteAddr: true,
+ },
+ {
+ name: "TCP",
+ proto: tcp.ProtocolNumber,
+ expectedConnectErr: &tcpip.ErrConnectStarted{},
+ setupServer: func(t *testing.T, ep tcpip.Endpoint) {
+ t.Helper()
+
+ if err := ep.Listen(1); err != nil {
+ t.Fatalf("ep.Listen(1): %s", err)
+ }
+ },
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ var addr tcpip.FullAddress
+ for {
+ newEP, wq, err := ep.Accept(&addr)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Accept(_): %s", err)
+ }
+ if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
+ "NIC",
+ )); diff != "" {
+ t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
+ }
+
+ we, newCH := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ return newEP, newCH
+ }
+ },
+ needRemoteAddr: false,
+ },
+ }
+
+ setupNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, target stack.Target) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.NATID, ipv6)
+ ruleIdx := filter.BuiltinChains[stack.Postrouting]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name}
+ filter.Rules[ruleIdx].Target = target
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, filter, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
+ }
+
+ natTypes := []struct {
+ name string
+ setupNAT func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber, tcpip.Address)
+ }{
+ {
+ name: "SNAT",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, natToAddr tcpip.Address) {
+ t.Helper()
+
+ setupNAT(t, s, netProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: natToAddr})
+ },
+ },
+ {
+ name: "Masquerade",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, natToAddr tcpip.Address) {
+ t.Helper()
+
+ setupNAT(t, s, netProto, &stack.MasqueradeTarget{NetworkProtocol: netProto})
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ for _, natType := range natTypes {
+ t.Run(natType.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+ utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
+
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
+
+ natType.setupNAT(t, routerStack, test.netProto, epsAndAddrs.nattedClientAddr)
+
+ serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
+ if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
+
+ if subTest.setupServer != nil {
+ subTest.setupServer(t, epsAndAddrs.serverEP)
+ }
+ {
+ err := epsAndAddrs.clientEP.Connect(serverAddr)
+ if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
+ t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
+ }
+ }
+ nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr}
+ if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
+ } else {
+ nattedClientAddr.Port = addr.Port
+ }
+
+ serverEP := epsAndAddrs.serverEP
+ serverCH := epsAndAddrs.serverReadableCH
+ if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil {
+ defer ep.Close()
+ serverEP = ep
+ serverCH = ch
+ }
+
+ write := func(ep tcpip.Endpoint, data []byte) {
+ t.Helper()
+
+ var r bytes.Reader
+ r.Reset(data)
+ var wOpts tcpip.WriteOptions
+ n, err := ep.Write(&r, wOpts)
+ if err != nil {
+ t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
+ }
+ }
+
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
+ t.Helper()
+
+ var buf bytes.Buffer
+ var res tcpip.ReadResult
+ for {
+ var err tcpip.Error
+ opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
+ res, err = ep.Read(&buf, opts)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ }
+ break
+ }
+
+ readResult := tcpip.ReadResult{
+ Count: len(data),
+ Total: len(data),
+ }
+ if subTest.needRemoteAddr {
+ readResult.RemoteAddr = expectedFrom
+ }
+ if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ {
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data)
+ read(serverCH, serverEP, data, nattedClientAddr)
+ }
+
+ {
+ data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
+ write(serverEP, data)
+ read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
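
TestSNAT drives both NAT flavors through the same Postrouting rule slot; only the target differs. As a minimal sketch of the same rule installation outside the test harness (assuming a configured *stack.Stack named s and an egress NIC named "routerNIC1"; both names are illustrative):

    // Rewrite the source address of IPv4 packets leaving "routerNIC1".
    // Pass true instead of false to edit the IPv6 NAT table.
    ipt := s.IPTables()
    natTable := ipt.GetTable(stack.NATID, false /* ipv6 */)
    ruleIdx := natTable.BuiltinChains[stack.Postrouting]
    natTable.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: "routerNIC1"}
    natTable.Rules[ruleIdx].Target = &stack.MasqueradeTarget{NetworkProtocol: ipv4.ProtocolNumber}
    // Keep the rule that follows from dropping the rewritten packet.
    natTable.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
    if err := ipt.ReplaceTable(stack.NATID, natTable, false /* ipv6 */); err != nil {
        // Handle the error; the test calls t.Fatalf here.
    }

The SNAT variant swaps in &stack.SNATTarget{NetworkProtocol: netProto, Addr: natToAddr}, pinning the rewritten source to a specific address, while Masquerade uses the outgoing NIC's primary address.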
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 27caa0c28..95ddd8ec3 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -56,17 +56,17 @@ func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tc
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv4Addr1, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv4Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv4Addr2, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv4Addr2, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv6Addr1, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv6Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv6Addr2, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv6Addr2, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
@@ -568,8 +568,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) {
Protocol: test.networkProtocolNumber,
AddressWithPrefix: test.incomingAddr,
}
- if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingProtoAddr, err)
}
// Set up endpoint through which we will attempt to forward packets.
@@ -582,8 +582,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) {
Protocol: test.networkProtocolNumber,
AddressWithPrefix: test.outgoingAddr,
}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
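
The mechanical change repeated across these test files is the new AddProtocolAddress signature: the stack.AddressProperties argument is now explicit and the two-argument form is gone. A hedged before/after sketch, with nicID and protoAddr standing in for any of the values above:

    // Old form (removed): s.AddProtocolAddress(nicID, protoAddr)
    // New form: a zero-valued AddressProperties preserves the old defaults.
    if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
        t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
    }

The accompanying switch from %#v to %+v in the Fatalf calls is cosmetic: %+v prints the address struct with field names but without the type annotations %#v adds.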
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index b2008f0b2..f33223e79 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -195,8 +195,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err)
+ if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -290,8 +290,8 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -431,8 +431,8 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err)
+ if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -693,21 +693,40 @@ func TestExternalLoopbackTraffic(t *testing.T) {
if err := s.CreateNIC(nicID1, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err)
+ v4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr,
}
- if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err)
+ if err := s.AddProtocolAddress(nicID1, v4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v4Addr, err)
+ }
+ v6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr,
+ }
+ if err := s.AddProtocolAddress(nicID1, v6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v6Addr, err)
}
if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: ipv4Loopback,
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: header.IPv6Loopback.WithPrefix(),
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if test.forwarding {
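
This file also retires the bare AddAddress and AddAddressWithPrefix helpers: callers now build a tcpip.ProtocolAddress themselves, which makes the prefix length explicit at every call site. Restating the two loopback conversions from the hunk above:

    v4Loopback := tcpip.ProtocolAddress{
        Protocol: ipv4.ProtocolNumber,
        AddressWithPrefix: tcpip.AddressWithPrefix{
            Address:   ipv4Loopback, // 127.0.0.1 in this test
            PrefixLen: 8,
        },
    }
    v6Loopback := tcpip.ProtocolAddress{
        Protocol:          ipv6.ProtocolNumber,
        AddressWithPrefix: header.IPv6Loopback.WithPrefix(),
    }

WithPrefix() attaches the address's full length as its prefix, so the IPv6 loopback becomes ::1/128 while the IPv4 side spells out the conventional 127.0.0.0/8.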
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 2d0a6e6a7..7753e7d6e 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -119,12 +119,12 @@ func TestPingMulticastBroadcast(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
}
ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr}
- if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err)
}
// Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
@@ -396,8 +396,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
var wq waiter.Queue
@@ -474,8 +474,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
PrefixLen: 8,
},
}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -642,8 +642,8 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
// Set the route table so that UDP can find a NIC that is
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index ac3c703d4..422eb8408 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -47,7 +47,10 @@ func TestLocalPing(t *testing.T) {
// request/reply packets.
icmpDataOffset = 8
)
- ipv4Loopback := testutil.MustParse4("127.0.0.1")
+ ipv4Loopback := tcpip.AddressWithPrefix{
+ Address: testutil.MustParse4("127.0.0.1"),
+ PrefixLen: 8,
+ }
channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
@@ -82,7 +85,7 @@ func TestLocalPing(t *testing.T) {
transProto tcpip.TransportProtocolNumber
netProto tcpip.NetworkProtocolNumber
linkEndpoint func() stack.LinkEndpoint
- localAddr tcpip.Address
+ localAddr tcpip.AddressWithPrefix
icmpBuf func(*testing.T) buffer.View
expectedConnectErr tcpip.Error
checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint)
@@ -101,7 +104,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
linkEndpoint: loopback.New,
- localAddr: header.IPv6Loopback,
+ localAddr: header.IPv6Loopback.WithPrefix(),
icmpBuf: ipv6ICMPBuf,
checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
},
@@ -110,7 +113,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber4,
netProto: ipv4.ProtocolNumber,
linkEndpoint: channelEP,
- localAddr: utils.Ipv4Addr.Address,
+ localAddr: utils.Ipv4Addr,
icmpBuf: ipv4ICMPBuf,
checkLinkEndpoint: channelEPCheck,
},
@@ -119,7 +122,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
linkEndpoint: channelEP,
- localAddr: utils.Ipv6Addr.Address,
+ localAddr: utils.Ipv6Addr,
icmpBuf: ipv6ICMPBuf,
checkLinkEndpoint: channelEPCheck,
},
@@ -182,9 +185,13 @@ func TestLocalPing(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if len(test.localAddr) != 0 {
- if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
+ if len(test.localAddr.Address) != 0 {
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.netProto,
+ AddressWithPrefix: test.localAddr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -197,7 +204,7 @@ func TestLocalPing(t *testing.T) {
}
defer ep.Close()
- connAddr := tcpip.FullAddress{Addr: test.localAddr}
+ connAddr := tcpip.FullAddress{Addr: test.localAddr.Address}
if err := ep.Connect(connAddr); err != test.expectedConnectErr {
t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
}
@@ -229,8 +236,8 @@ func TestLocalPing(t *testing.T) {
if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" {
t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
- if rr.RemoteAddr.Addr != test.localAddr {
- t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr)
+ if rr.RemoteAddr.Addr != test.localAddr.Address {
+ t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address)
}
test.checkLinkEndpoint(t, e)
@@ -302,11 +309,12 @@ func TestLocalUDP(t *testing.T) {
}
if subTest.addAddress {
- if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil {
- t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err)
}
- if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err)
}
}
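
AddProtocolAddressWithOptions is folded into the same call: the primary-endpoint behavior that used to be a positional option is now the PEB field of AddressProperties. Since CanBePrimaryEndpoint is PEB's zero value (which is why the canBePrimaryAddr call above passes empty properties), only FirstPrimaryEndpoint needs to be spelled out:

    // Equivalent to the removed AddProtocolAddressWithOptions(nicID, addr,
    // stack.FirstPrimaryEndpoint); unset properties keep their defaults.
    properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
    if err := s.AddProtocolAddress(nicID, addr, properties); err != nil {
        // Handle the error.
    }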
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index 2e6ae55ea..c69410859 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -40,6 +40,14 @@ const (
Host2NICID = 4
)
+// Common NIC names used by tests.
+const (
+ Host1NICName = "host1NIC"
+ RouterNIC1Name = "routerNIC1"
+ RouterNIC2Name = "routerNIC2"
+ Host2NICName = "host2NIC"
+)
+
// Common link addresses used by tests.
const (
LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
@@ -211,17 +219,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2)
routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4)
- if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host1NICName}
+ if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil {
+ t.Fatalf("host1Stack.CreateNICWithOptions(%d, _, %#v): %s", Host1NICID, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC1Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC2Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err)
+ }
}
- if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host2NICName}
+ if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil {
+ t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err)
+ }
}
if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -231,29 +251,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err)
}
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv4Addr, err)
+ if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv4Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
}
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv4Addr, err)
+ if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv4Addr, err)
}
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv6Addr, err)
+ if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv6Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
}
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv6Addr, err)
+ if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv6Addr, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
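
Giving the NICs stable names here is what lets TestSNAT's Postrouting rule match traffic by interface: the stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name} filter in that test only fires if the router NIC was created under the same name. The pairing, condensed:

    // The name passed at NIC creation...
    opts := stack.NICOptions{Name: RouterNIC1Name}
    if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil {
        t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err)
    }
    // ...must match the name the NAT rule filters on:
    // filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: RouterNIC1Name}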
diff --git a/pkg/syserror/BUILD b/pkg/tcpip/transport/BUILD
index 76bee5a64..af332ed91 100644
--- a/pkg/syserror/BUILD
+++ b/pkg/tcpip/transport/BUILD
@@ -3,8 +3,11 @@ load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
- name = "syserror",
- srcs = ["syserror.go"],
+ name = "transport",
+ srcs = [
+ "datagram.go",
+ "transport.go",
+ ],
visibility = ["//visibility:public"],
- deps = ["@org_golang_x_sys//unix:go_default_library"],
+ deps = ["//pkg/tcpip"],
)
diff --git a/pkg/tcpip/transport/datagram.go b/pkg/tcpip/transport/datagram.go
new file mode 100644
index 000000000..dfce72c69
--- /dev/null
+++ b/pkg/tcpip/transport/datagram.go
@@ -0,0 +1,49 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// DatagramEndpointState is the state of a datagram-based endpoint.
+type DatagramEndpointState tcpip.EndpointState
+
+// The states a datagram-based endpoint may be in.
+const (
+ _ DatagramEndpointState = iota
+ DatagramEndpointStateInitial
+ DatagramEndpointStateBound
+ DatagramEndpointStateConnected
+ DatagramEndpointStateClosed
+)
+
+// String implements fmt.Stringer.
+func (s DatagramEndpointState) String() string {
+ switch s {
+ case DatagramEndpointStateInitial:
+ return "INITIAL"
+ case DatagramEndpointStateBound:
+ return "BOUND"
+ case DatagramEndpointStateConnected:
+ return "CONNECTED"
+ case DatagramEndpointStateClosed:
+ return "CLOSED"
+ default:
+ panic(fmt.Sprintf("unhandled %[1]T variant = %[1]d", s))
+ }
+}
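
This enum replaces the private endpointState copies that the datagram transports kept (the icmp diff below deletes its version). A sketch of the switch the transports now share, where e.net is the embedded network.Endpoint introduced later in this change:

    switch state := e.net.State(); state {
    case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
        // No stack registration exists in these states.
    case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
        // The endpoint holds a registration that must be released on close.
    default:
        panic(fmt.Sprintf("unhandled state = %s", state))
    }

Panicking on the default arm matches the String method above: an unknown state is a programming error, not a runtime condition to tolerate.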
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index bbc0e3ecc..4718ec4ec 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -33,6 +33,8 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/waiter",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index f9a15efb2..31579a896 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,6 +15,7 @@
package icmp
import (
+ "fmt"
"io"
"time"
@@ -24,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -35,15 +38,6 @@ type icmpPacket struct {
receivedAt time.Time `state:".(int64)"`
}
-type endpointState int
-
-const (
- stateInitial endpointState = iota
- stateBound
- stateConnected
- stateClosed
-)
-
// endpoint represents an ICMP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -51,14 +45,17 @@ const (
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
+ transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
uniqueID uint64
+ net network.Endpoint
+ stats tcpip.TransportEndpointStats
+ ops tcpip.SocketOptions
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -70,38 +67,23 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
- // shutdownFlags represent the current shutdown state of the endpoint.
- shutdownFlags tcpip.ShutdownFlags
- state endpointState
- route *stack.Route `state:"manual"`
- ttl uint8
- stats tcpip.TransportEndpointStats `state:"nosave"`
-
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
+ ident uint16
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: transProto,
- },
+ stack: s,
+ transProto: transProto,
waiterQueue: waiterQueue,
- state: stateInitial,
uniqueID: s.UniqueID(),
}
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetSendBufferSize(32*1024, false /* notify */)
ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ ep.net.Init(s, netProto, transProto, &ep.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -128,35 +110,40 @@ func (e *endpoint) Abort() {
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
- e.mu.Lock()
- e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.state {
- case stateBound, stateConnected:
- bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice)
- }
-
- // Close the receive list and drain it.
- e.rcvMu.Lock()
- e.rcvClosed = true
- e.rcvBufSize = 0
- for !e.rcvList.Empty() {
- p := e.rcvList.Front()
- e.rcvList.Remove(p)
- }
- e.rcvMu.Unlock()
+ notify := func() bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateClosed:
+ return false
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ info := e.net.Info()
+ info.ID.LocalPort = e.ident
+ e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{info.NetProto}, e.transProto, info.ID, e, ports.Flags{}, tcpip.NICID(e.ops.GetBindToDevice()))
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
+ e.net.Shutdown()
+ e.net.Close()
- // Update the state.
- e.state = stateClosed
+ e.rcvMu.Lock()
+ defer e.rcvMu.Unlock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
- e.mu.Unlock()
+ return true
+ }()
- e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+ if notify {
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+ }
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
@@ -164,7 +151,7 @@ func (*endpoint) ModerateRecvBuf(int) {}
// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.owner = owner
+ e.net.SetOwner(owner)
}
// Read implements tcpip.Endpoint.Read.
@@ -193,7 +180,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: p.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.receivedAt.UnixNano(),
+ Timestamp: p.receivedAt,
},
}
if opts.NeedRemoteAddr {
@@ -214,13 +201,12 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
//
// Returns true for retry if preparation should be retried.
// +checklocks:e.mu
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
- switch e.state {
- case stateInitial:
- case stateConnected:
+func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
+ switch e.net.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
return false, nil
-
- case stateBound:
+ case transport.DatagramEndpointStateBound:
if to == nil {
return false, &tcpip.ErrDestinationRequired{}
}
@@ -235,7 +221,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
// The state changed when we released the shared lock and re-acquired
// it in exclusive mode. Try again.
- if e.state != stateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return true, nil
}
@@ -270,27 +256,15 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
-
- to := opts.To
-
+func (e *endpoint) prepareForWrite(opts tcpip.WriteOptions) (network.WriteContext, uint16, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- // If we've shutdown with SHUT_WR we are in an invalid state for sending.
- if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, &tcpip.ErrClosedForSend{}
- }
-
// Prepare for write.
for {
- retry, err := e.prepareForWrite(to)
+ retry, err := e.prepareForWriteInner(opts.To)
if err != nil {
- return 0, err
+ return network.WriteContext{}, 0, err
}
if !retry {
@@ -298,53 +272,35 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
}
- route := e.route
- if to != nil {
- // Reject destination address if it goes through a different
- // NIC than the endpoint was bound to.
- nicID := to.NIC
- if nicID == 0 {
- nicID = tcpip.NICID(e.ops.GetBindToDevice())
- }
- if e.BindNICID != 0 {
- if nicID != 0 && nicID != e.BindNICID {
- return 0, &tcpip.ErrNoRoute{}
- }
-
- nicID = e.BindNICID
- }
-
- dst, netProto, err := e.checkV4MappedLocked(*to)
- if err != nil {
- return 0, err
- }
-
- // Find the endpoint.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */)
- if err != nil {
- return 0, err
- }
- defer r.Release()
+ ctx, err := e.net.AcquireContextForWrite(opts)
+ return ctx, e.ident, err
+}
- route = r
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ ctx, ident, err := e.prepareForWrite(opts)
+ if err != nil {
+ return 0, err
}
+ defer ctx.Release()
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
v := make([]byte, p.Len())
if _, err := io.ReadFull(p, v); err != nil {
return 0, &tcpip.ErrBadBuffer{}
}
- var err tcpip.Error
- switch e.NetProto {
+ switch netProto, pktInfo := e.net.NetProto(), ctx.PacketInfo(); netProto {
case header.IPv4ProtocolNumber:
- err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
+ if err := send4(e.stack, &ctx, ident, v, pktInfo.MaxHeaderLength); err != nil {
+ return 0, err
+ }
case header.IPv6ProtocolNumber:
- err = send6(route, e.ID.LocalPort, v, e.ttl)
- }
-
- if err != nil {
- return 0, err
+ if err := send6(e.stack, &ctx, ident, v, pktInfo.LocalAddress, pktInfo.RemoteAddress, pktInfo.MaxHeaderLength); err != nil {
+ return 0, err
+ }
+ default:
+ panic(fmt.Sprintf("unhandled network protocol = %d", netProto))
}
return int64(len(v)), nil
@@ -357,24 +313,17 @@ func (e *endpoint) HasNIC(id int32) bool {
return e.stack.HasNIC(tcpip.NICID(id))
}
-// SetSockOpt sets a socket option.
-func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
- return nil
+// SetSockOpt implements tcpip.Endpoint.
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
+ return e.net.SetSockOpt(opt)
}
-// SetSockOptInt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(v)
- e.mu.Unlock()
-
- }
- return nil
+ return e.net.SetSockOptInt(opt, v)
}
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -387,31 +336,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.TTLOption:
- e.rcvMu.Lock()
- v := int(e.ttl)
- e.rcvMu.Unlock()
- return v, nil
-
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+// GetSockOpt implements tcpip.Endpoint.
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return e.net.GetSockOpt(opt)
}
-func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error {
+func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, maxHeaderLength uint16) tcpip.Error {
if len(data) < header.ICMPv4MinimumSize {
return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: header.ICMPv4MinimumSize + int(maxHeaderLength),
})
- pkt.Owner = owner
icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize))
pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
@@ -426,36 +368,31 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
return &tcpip.ErrInvalidEndpointState{}
}
- // Because this icmp endpoint is implemented in the transport layer, we can
- // only increment the 'stack-wide' stats but we can't increment the
- // 'per-NetworkEndpoint' stats.
- sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest
-
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
-
pkt.Data().AppendView(data)
- if ttl == 0 {
- ttl = r.DefaultTTL()
- }
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ stats := s.Stats().ICMP.V4.PacketsSent
- if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
- r.Stats().ICMP.V4.PacketsSent.Dropped.Increment()
+ if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ stats.Dropped.Increment()
return err
}
- sentStat.Increment()
+ stats.EchoRequest.Increment()
return nil
}
-func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error {
+func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, src, dst tcpip.Address, maxHeaderLength uint16) tcpip.Error {
if len(data) < header.ICMPv6EchoMinimumSize {
return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: header.ICMPv6MinimumSize + int(maxHeaderLength),
})
icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize))
@@ -468,43 +405,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
return &tcpip.ErrInvalidEndpointState{}
}
- // Because this icmp endpoint is implemented in the transport layer, we can
- // only increment the 'stack-wide' stats but we can't increment the
- // 'per-NetworkEndpoint' stats.
- sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest
pkt.Data().AppendView(data)
dataRange := pkt.Data().AsRange()
icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpv6,
- Src: r.LocalAddress(),
- Dst: r.RemoteAddress(),
+ Src: src,
+ Dst: dst,
PayloadCsum: dataRange.Checksum(),
PayloadLen: dataRange.Size(),
}))
- if ttl == 0 {
- ttl = r.DefaultTTL()
- }
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ stats := s.Stats().ICMP.V6.PacketsSent
- if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
- r.Stats().ICMP.V6.PacketsSent.Dropped.Increment()
+ if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ stats.Dropped.Increment()
+ return err
}
- sentStat.Increment()
+ stats.EchoRequest.Increment()
return nil
}
-// checkV4MappedLocked determines the effective network protocol and converts
-// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
- if err != nil {
- return tcpip.FullAddress{}, 0, err
- }
- return unwrapped, netProto, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect.
func (*endpoint) Disconnect() tcpip.Error {
return &tcpip.ErrNotSupported{}
@@ -515,59 +440,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicID := addr.NIC
- localPort := uint16(0)
- switch e.state {
- case stateInitial:
- case stateBound, stateConnected:
- localPort = e.ID.LocalPort
- if e.BindNICID == 0 {
- break
- }
+ err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
+ nextID.LocalPort = e.ident
- if nicID != 0 && nicID != e.BindNICID {
- return &tcpip.ErrInvalidEndpointState{}
+ nextID, err := e.registerWithStack(netProto, nextID)
+ if err != nil {
+ return err
}
- nicID = e.BindNICID
- default:
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
- if err != nil {
- return err
- }
-
- id := stack.TransportEndpointID{
- LocalAddress: r.LocalAddress(),
- LocalPort: localPort,
- RemoteAddress: r.RemoteAddress(),
- }
-
- // Even if we're connected, this endpoint can still be used to send
- // packets on a different network protocol, so we register both even if
- // v6only is set to false and this is an ipv6 endpoint.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
-
- id, err = e.registerWithStack(nicID, netProtos, id)
+ e.ident = nextID.LocalPort
+ return nil
+ })
if err != nil {
- r.Release()
return err
}
- e.ID = id
- e.route = r
- e.RegisterNICID = nicID
-
- e.state = stateConnected
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
@@ -585,10 +472,19 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- e.shutdownFlags |= flags
- if e.state != stateConnected {
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+
+ if flags&tcpip.ShutdownWrite != 0 {
+ if err := e.net.Shutdown(); err != nil {
+ return err
+ }
}
if flags&tcpip.ShutdownRead != 0 {
@@ -615,19 +511,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
return nil, nil, &tcpip.ErrNotSupported{}
}
-func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
+func (e *endpoint) registerWithStack(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
- return id, err
+ return id, e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
+ err := e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
switch err.(type) {
case nil:
return true, nil
@@ -644,42 +539,27 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro
func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != stateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Expand netProtos to include v4 and v6 if the caller is binding to a
- // wildcard (empty) address, and this is an IPv6 endpoint with v6only
- // set to false.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
-
- if len(addr.Addr) != 0 {
- // A local address was specified, verify that it's valid.
- if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
- return &tcpip.ErrBadLocalAddress{}
+ err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err := e.registerWithStack(boundNetProto, id)
+ if err != nil {
+ return err
}
- }
- id := stack.TransportEndpointID{
- LocalPort: addr.Port,
- LocalAddress: addr.Addr,
- }
- id, err = e.registerWithStack(addr.NIC, netProtos, id)
+ e.ident = id.LocalPort
+ return nil
+ })
if err != nil {
return err
}
- e.ID = id
- e.RegisterNICID = addr.NIC
-
- // Mark endpoint as bound.
- e.state = stateBound
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
@@ -687,21 +567,24 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
return nil
}
+func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast ||
+ header.IsV4MulticastAddress(addr) ||
+ header.IsV6MulticastAddress(addr) ||
+ e.stack.IsSubnetBroadcast(nicID, e.net.NetProto(), addr)
+}
+
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- err := e.bindLocked(addr)
- if err != nil {
- return err
+ if len(addr.Addr) != 0 && e.isBroadcastOrMulticast(addr.NIC, addr.Addr) {
+ return &tcpip.ErrBadLocalAddress{}
}
- e.BindNICID = addr.NIC
- e.BindAddr = addr.Addr
+ e.mu.Lock()
+ defer e.mu.Unlock()
- return nil
+ return e.bindLocked(addr)
}
// GetLocalAddress returns the address to which the endpoint is bound.
@@ -709,11 +592,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
- }, nil
+ addr := e.net.GetLocalAddress()
+ addr.Port = e.ident
+ return addr, nil
}
// GetRemoteAddress returns the address to which the endpoint is connected.
@@ -721,15 +602,11 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != stateConnected {
- return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
+ if addr, connected := e.net.GetRemoteAddress(); connected {
+ return addr, nil
}
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
- }, nil
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
// Readiness returns the current readiness of the endpoint. For example, if
@@ -754,7 +631,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Only accept echo replies.
- switch e.NetProto {
+ switch e.net.NetProto() {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(pkt.TransportHeader().View())
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
@@ -828,9 +705,9 @@ func (e *endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
+ defer e.mu.RUnlock()
+ ret := e.net.Info()
+ ret.ID.LocalPort = e.ident
return &ret
}
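
The shape of the refactor is clearest in Connect and bindLocked: the icmp endpoint no longer tracks routes, bind addresses, or shutdown flags itself. It keeps only its ident, delegates the datagram state machine to e.net, and claims a port from a callback once the network layer has settled the addresses. Condensed from the hunks above:

    // network.Endpoint resolves addresses and routes, then invokes the
    // callback with the before/after endpoint IDs so the transport can
    // register its port (for ICMP, the echo ident) with the stack.
    err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
        nextID.LocalPort = e.ident
        nextID, err := e.registerWithStack(netProto, nextID)
        if err != nil {
            return err
        }
        e.ident = nextID.LocalPort
        return nil
    })

The manual route bookkeeping and the Release call on the error path disappear because the connected route now lives inside the network.Endpoint (the connectedRoute field added below).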
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index b8b839e4a..dfe453ff9 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -15,11 +15,13 @@
package icmp
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
)
// saveReceivedAt is invoked by stateify.
@@ -61,29 +63,24 @@ func (e *endpoint) beforeSave() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.thaw()
+
+ e.net.Resume(s)
+
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- if e.state != stateBound && e.state != stateConnected {
- return
- }
-
- var err tcpip.Error
- if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */)
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ var err tcpip.Error
+ info := e.net.Info()
+ info.ID.LocalPort = e.ident
+ info.ID, err = e.registerWithStack(info.NetProto, info.ID)
if err != nil {
- panic(err)
+ panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err))
}
-
- e.ID.LocalAddress = e.route.LocalAddress()
- } else if len(e.ID.LocalAddress) != 0 { // stateBound
- if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
- e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID)
- if err != nil {
- panic(err)
+ e.ident = info.ID.LocalPort
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
}
diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go
index cc950cbde..729f50e9a 100644
--- a/pkg/tcpip/transport/icmp/icmp_test.go
+++ b/pkg/tcpip/transport/icmp/icmp_test.go
@@ -55,8 +55,12 @@ func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name s
t.Fatalf("s.CreateNIC(%d, _) = %s", id, err)
}
- if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
}
s.AddRoute(tcpip.Route{
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
new file mode 100644
index 000000000..3818cb04e
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/BUILD
@@ -0,0 +1,46 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "network",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ ],
+ visibility = [
+ "//pkg/tcpip/transport/icmp:__pkg__",
+ "//pkg/tcpip/transport/raw:__pkg__",
+ "//pkg/tcpip/transport/udp:__pkg__",
+ ],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ ],
+)
+
+go_test(
+ name = "network_test",
+ size = "small",
+ srcs = ["endpoint_test.go"],
+ deps = [
+ ":network",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/udp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
new file mode 100644
index 000000000..e3094f59f
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -0,0 +1,811 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package network provides facilities to support tcpip.Endpoints that operate
+// at the network layer or above.
+package network
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+)
+
+// Endpoint is a datagram-based endpoint. It only supports sending datagrams to
+// a peer.
+//
+// +stateify savable
+type Endpoint struct {
+ // The following fields must only be set once then never changed.
+ stack *stack.Stack `state:"manual"`
+ ops *tcpip.SocketOptions
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ wasBound bool
+ // owner is the owner of transmitted packets.
+ //
+ // +checklocks:mu
+ owner tcpip.PacketOwner
+ // +checklocks:mu
+ writeShutdown bool
+ // +checklocks:mu
+ effectiveNetProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
+ connectedRoute *stack.Route `state:"manual"`
+ // +checklocks:mu
+ multicastMemberships map[multicastMembership]struct{}
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
+ ttl uint8
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
+ multicastTTL uint8
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
+ multicastAddr tcpip.Address
+ // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
+ multicastNICID tcpip.NICID
+ // +checklocks:mu
+ ipv4TOS uint8
+ // +checklocks:mu
+ ipv6TClass uint8
+
+ // Lock ordering: mu > infoMu.
+ infoMu sync.RWMutex `state:"nosave"`
+ // info has a dedicated mutex so that we can avoid lock ordering violations
+ // when reading the endpoint's info. If we used mu, we would need to
+ // guarantee that no lock taken while mu is held is also held when calling
+ // Info(), which is not true as of writing: we hold mu while registering
+ // transport endpoints (taking the transport demuxer lock), but we also hold
+ // the demuxer lock when delivering packets/errors to endpoints.
+ //
+ // Writes must be performed through setInfo.
+ //
+ // +checklocks:infoMu
+ info stack.TransportEndpointInfo
+
+ // state holds a transport.DatagramEndpointState.
+ //
+ // state must be accessed with atomics so that we can avoid lock ordering
+ // violations when reading the state. If we used mu, we would need to
+ // guarantee that no lock taken while mu is held is also held when calling
+ // State(), which is not true as of writing: we hold mu while registering
+ // transport endpoints (taking the transport demuxer lock), but we also hold
+ // the demuxer lock when delivering packets/errors to endpoints.
+ //
+ // Writes must be performed through setEndpointState.
+ //
+ // +checkatomics
+ state uint32
+}
+
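+// multicastMembership identifies a multicast group joined on a specific NIC.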
+// +stateify savable
+type multicastMembership struct {
+ nicID tcpip.NICID
+ multicastAddr tcpip.Address
+}
+
+// Init initializes the endpoint.
+func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) {
+ e.mu.Lock()
+ memberships := e.multicastMemberships
+ e.mu.Unlock()
+ if memberships != nil {
+ panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships))
+ }
+
+ switch netProto {
+ case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
+ default:
+ panic(fmt.Sprintf("invalid protocol number = %d", netProto))
+ }
+
+ *e = Endpoint{
+ stack: s,
+ ops: ops,
+ netProto: netProto,
+ transProto: transProto,
+
+ info: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ TransProto: transProto,
+ },
+ effectiveNetProto: netProto,
+ // Linux defaults to TTL=1.
+ multicastTTL: 1,
+ multicastMemberships: make(map[multicastMembership]struct{}),
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.setEndpointState(transport.DatagramEndpointStateInitial)
+}
+
+// NetProto returns the network protocol the endpoint was initialized with.
+func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber {
+ return e.netProto
+}
+
+// setEndpointState sets the state of the endpoint.
+//
+// e.mu must be held to synchronize changes to state with the rest of the
+// endpoint.
+//
+// +checklocks:e.mu
+func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) {
+ atomic.StoreUint32(&e.state, uint32(state))
+}
+
+// State returns the state of the endpoint.
+func (e *Endpoint) State() transport.DatagramEndpointState {
+ return transport.DatagramEndpointState(atomic.LoadUint32(&e.state))
+}
+
+// Close releases the endpoint's resources and leaves the endpoint in a
+// closed state.
+func (e *Endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.State() == transport.DatagramEndpointStateClosed {
+ return
+ }
+
+ for mem := range e.multicastMemberships {
+ e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ }
+ e.multicastMemberships = nil
+
+ if e.connectedRoute != nil {
+ e.connectedRoute.Release()
+ e.connectedRoute = nil
+ }
+
+ e.setEndpointState(transport.DatagramEndpointStateClosed)
+}
+
+// SetOwner sets the owner of transmitted packets.
+func (e *Endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.owner = owner
+}
+
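+// calculateTTL returns the TTL to use for a packet sent on route: the
+// multicast TTL for multicast destinations, the configured TTL if one is
+// set, and the route's default TTL otherwise.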
+func calculateTTL(route *stack.Route, ttl uint8, multicastTTL uint8) uint8 {
+ if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) {
+ return multicastTTL
+ }
+
+ if ttl == 0 {
+ return route.DefaultTTL()
+ }
+
+ return ttl
+}
+
+// WriteContext holds the context for a write.
+type WriteContext struct {
+ transProto tcpip.TransportProtocolNumber
+ route *stack.Route
+ ttl uint8
+ tos uint8
+ owner tcpip.PacketOwner
+}
+
+// Release releases held resources.
+func (c *WriteContext) Release() {
+ c.route.Release()
+ *c = WriteContext{}
+}
+
+// WritePacketInfo holds the properties of a packet that may be written.
+type WritePacketInfo struct {
+ NetProto tcpip.NetworkProtocolNumber
+ LocalAddress, RemoteAddress tcpip.Address
+ MaxHeaderLength uint16
+ RequiresTXTransportChecksum bool
+}
+
+// PacketInfo returns the properties of a packet that will be written.
+func (c *WriteContext) PacketInfo() WritePacketInfo {
+ return WritePacketInfo{
+ NetProto: c.route.NetProto(),
+ LocalAddress: c.route.LocalAddress(),
+ RemoteAddress: c.route.RemoteAddress(),
+ MaxHeaderLength: c.route.MaxHeaderLength(),
+ RequiresTXTransportChecksum: c.route.RequiresTXTransportChecksum(),
+ }
+}
+
+// WritePacket attempts to write the packet.
+func (c *WriteContext) WritePacket(pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
+ pkt.Owner = c.owner
+
+ if headerIncluded {
+ return c.route.WriteHeaderIncludedPacket(pkt)
+ }
+
+ return c.route.WritePacket(stack.NetworkHeaderParams{
+ Protocol: c.transProto,
+ TTL: c.ttl,
+ TOS: c.tos,
+ }, pkt)
+}
+
+// AcquireContextForWrite acquires a WriteContext.
+func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext, tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
+ if opts.More {
+ return WriteContext{}, &tcpip.ErrInvalidOptionValue{}
+ }
+
+ if e.State() == transport.DatagramEndpointStateClosed {
+ return WriteContext{}, &tcpip.ErrInvalidEndpointState{}
+ }
+
+ if e.writeShutdown {
+ return WriteContext{}, &tcpip.ErrClosedForSend{}
+ }
+
+ route := e.connectedRoute
+ if opts.To == nil {
+ // If the user doesn't specify a destination, they should have
+ // connected to another address.
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return WriteContext{}, &tcpip.ErrDestinationRequired{}
+ }
+
+ route.Acquire()
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicID := opts.To.NIC
+ if nicID == 0 {
+ nicID = tcpip.NICID(e.ops.GetBindToDevice())
+ }
+ info := e.Info()
+ if info.BindNICID != 0 {
+ if nicID != 0 && nicID != info.BindNICID {
+ return WriteContext{}, &tcpip.ErrNoRoute{}
+ }
+
+ nicID = info.BindNICID
+ }
+ if nicID == 0 {
+ nicID = info.RegisterNICID
+ }
+
+ dst, netProto, err := e.checkV4Mapped(*opts.To)
+ if err != nil {
+ return WriteContext{}, err
+ }
+
+ route, _, err = e.connectRouteRLocked(nicID, dst, netProto)
+ if err != nil {
+ return WriteContext{}, err
+ }
+ }
+
+ if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
+ route.Release()
+ return WriteContext{}, &tcpip.ErrBroadcastDisabled{}
+ }
+
+ var tos uint8
+ switch netProto := route.NetProto(); netProto {
+ case header.IPv4ProtocolNumber:
+ tos = e.ipv4TOS
+ case header.IPv6ProtocolNumber:
+ tos = e.ipv6TClass
+ default:
+ panic(fmt.Sprintf("invalid protocol number = %d", netProto))
+ }
+
+ return WriteContext{
+ transProto: e.transProto,
+ route: route,
+ ttl: calculateTTL(route, e.ttl, e.multicastTTL),
+ tos: tos,
+ owner: e.owner,
+ }, nil
+}
+
+// Disconnect disconnects the endpoint from its peer.
+func (e *Endpoint) Disconnect() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return
+ }
+
+ info := e.Info()
+ // Exclude ephemerally bound endpoints.
+ if e.wasBound {
+ info.ID = stack.TransportEndpointID{
+ LocalAddress: info.BindAddr,
+ }
+ e.setEndpointState(transport.DatagramEndpointStateBound)
+ } else {
+ info.ID = stack.TransportEndpointID{}
+ e.setEndpointState(transport.DatagramEndpointStateInitial)
+ }
+ e.setInfo(info)
+
+ e.connectedRoute.Release()
+ e.connectedRoute = nil
+}
+
+// connectRouteRLocked establishes a route to addr, falling back to the
+// configured multicast interface when no interface is specified and addr is
+// a multicast address.
+//
+// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement.
+// +checklocks:e.mu
+func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
+ localAddr := e.Info().ID.LocalAddress
+ if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
+ // A packet can only originate from a unicast address (i.e., an interface).
+ localAddr = ""
+ }
+
+ if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
+ if nicID == 0 {
+ nicID = e.multicastNICID
+ }
+ if localAddr == "" && nicID == 0 {
+ localAddr = e.multicastAddr
+ }
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
+ if err != nil {
+ return nil, 0, err
+ }
+ return r, nicID, nil
+}
+
+// Connect connects the endpoint to the address.
+func (e *Endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+ return e.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error {
+ return nil
+ })
+}
+
+// ConnectAndThen connects the endpoint to the address and then calls the
+// provided function.
+//
+// If the function returns an error, the endpoint's state does not change. The
+// function will be called with the network protocol used to connect to the peer
+// and the source and destination addresses that will be used to send traffic to
+// the peer.
+func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error) tcpip.Error {
+ addr.Port = 0
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ info := e.Info()
+ nicID := addr.NIC
+ switch e.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ if info.BindNICID == 0 {
+ break
+ }
+
+ if nicID != 0 && nicID != info.BindNICID {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ nicID = info.BindNICID
+ default:
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ addr, netProto, err := e.checkV4Mapped(addr)
+ if err != nil {
+ return err
+ }
+
+ r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto)
+ if err != nil {
+ return err
+ }
+
+ id := stack.TransportEndpointID{
+ LocalAddress: info.ID.LocalAddress,
+ RemoteAddress: r.RemoteAddress(),
+ }
+ if e.State() == transport.DatagramEndpointStateInitial {
+ id.LocalAddress = r.LocalAddress()
+ }
+
+ if err := f(r.NetProto(), info.ID, id); err != nil {
+ return err
+ }
+
+ if e.connectedRoute != nil {
+ // If the endpoint was previously connected then release any previous route.
+ e.connectedRoute.Release()
+ }
+ e.connectedRoute = r
+ info.ID = id
+ info.RegisterNICID = nicID
+ e.setInfo(info)
+ e.effectiveNetProto = netProto
+ e.setEndpointState(transport.DatagramEndpointStateConnected)
+ return nil
+}
+
+// Shutdown shuts down the endpoint.
+func (e *Endpoint) Shutdown() tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch state := e.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ e.writeShutdown = true
+ return nil
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+}
+
+// checkV4Mapped determines the effective network protocol and converts addr
+// to its canonical form.
+func (e *Endpoint) checkV4Mapped(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
+ info := e.Info()
+ unwrapped, netProto, err := info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
+ if err != nil {
+ return tcpip.FullAddress{}, 0, err
+ }
+ return unwrapped, netProto, nil
+}
+
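+// isBroadcastOrMulticast returns true if addr is a broadcast, subnet
+// broadcast, or multicast address.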
+func (e *Endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
+}
+
+// Bind binds the endpoint to the address.
+func (e *Endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
+ return e.BindAndThen(addr, func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error {
+ return nil
+ })
+}
+
+// BindAndThen binds the endpoint to the address and then calls the provided
+// function.
+//
+// If the function returns an error, the endpoint's state does not change. The
+// function will be called with the bound network protocol and address.
+func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error) tcpip.Error {
+ addr.Port = 0
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Don't allow binding once the endpoint is no longer in the initial
+ // state.
+ if e.State() != transport.DatagramEndpointStateInitial {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ addr, netProto, err := e.checkV4Mapped(addr)
+ if err != nil {
+ return err
+ }
+
+ nicID := addr.NIC
+ if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
+ nicID = e.stack.CheckLocalAddress(nicID, netProto, addr.Addr)
+ if nicID == 0 {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ }
+
+ if err := f(netProto, addr.Addr); err != nil {
+ return err
+ }
+
+ e.wasBound = true
+
+ info := e.Info()
+ info.ID = stack.TransportEndpointID{
+ LocalAddress: addr.Addr,
+ }
+ info.BindNICID = addr.NIC
+ info.RegisterNICID = nicID
+ info.BindAddr = addr.Addr
+ e.setInfo(info)
+ e.effectiveNetProto = netProto
+ e.setEndpointState(transport.DatagramEndpointStateBound)
+ return nil
+}
+
+// WasBound returns true iff the endpoint was ever bound.
+func (e *Endpoint) WasBound() bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.wasBound
+}
+
+// GetLocalAddress returns the address that the endpoint is bound to.
+func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ info := e.Info()
+ addr := info.BindAddr
+ if e.State() == transport.DatagramEndpointStateConnected {
+ addr = e.connectedRoute.LocalAddress()
+ }
+
+ return tcpip.FullAddress{
+ NIC: info.RegisterNICID,
+ Addr: addr,
+ }
+}
+
+// GetRemoteAddress returns the address that the endpoint is connected to.
+func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.State() != transport.DatagramEndpointStateConnected {
+ return tcpip.FullAddress{}, false
+ }
+
+ return tcpip.FullAddress{
+ Addr: e.connectedRoute.RemoteAddress(),
+ NIC: e.Info().RegisterNICID,
+ }, true
+}
+
+// SetSockOptInt sets the socket option.
+func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+ switch opt {
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if the value is not disabling path
+ // MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return &tcpip.ErrNotSupported{}
+ }
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ e.multicastTTL = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv4TOSOption:
+ e.mu.Lock()
+ e.ipv4TOS = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ e.ipv6TClass = uint8(v)
+ e.mu.Unlock()
+ }
+
+ return nil
+}
+
+// GetSockOptInt returns the socket option.
+func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
+ switch opt {
+ case tcpip.MTUDiscoverOption:
+ // The only supported setting is path MTU discovery disabled.
+ return tcpip.PMTUDiscoveryDont, nil
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ v := int(e.multicastTTL)
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ v := int(e.ttl)
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.IPv4TOSOption:
+ e.mu.RLock()
+ v := int(e.ipv4TOS)
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ v := int(e.ipv6TClass)
+ e.mu.RUnlock()
+ return v, nil
+
+ default:
+ return -1, &tcpip.ErrUnknownProtocolOption{}
+ }
+}
+
+// SetSockOpt sets the socket option.
+func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
+ switch v := opt.(type) {
+ case *tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
+ fa, netProto, err := e.checkV4Mapped(fa)
+ if err != nil {
+ return err
+ }
+ nic := v.NIC
+ addr := fa.Addr
+
+ if nic == 0 && addr == "" {
+ e.multicastAddr = ""
+ e.multicastNICID = 0
+ break
+ }
+
+ if nic != 0 {
+ if !e.stack.CheckNIC(nic) {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ } else {
+ nic = e.stack.CheckLocalAddress(0, netProto, addr)
+ if nic == 0 {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+ }
+
+ if info := e.Info(); info.BindNICID != 0 && info.BindNICID != nic {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+
+ e.multicastNICID = nic
+ e.multicastAddr = addr
+
+ case *tcpip.AddMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return &tcpip.ErrInvalidOptionValue{}
+ }
+
+ nicID := v.NIC
+
+ if v.InterfaceAddr.Unspecified() {
+ if nicID == 0 {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return &tcpip.ErrUnknownDevice{}
+ }
+
+ memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if _, ok := e.multicastMemberships[memToInsert]; ok {
+ return &tcpip.ErrPortInUse{}
+ }
+
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships[memToInsert] = struct{}{}
+
+ case *tcpip.RemoveMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return &tcpip.ErrInvalidOptionValue{}
+ }
+
+ nicID := v.NIC
+ if v.InterfaceAddr.Unspecified() {
+ if nicID == 0 {
+ if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return &tcpip.ErrUnknownDevice{}
+ }
+
+ memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if _, ok := e.multicastMemberships[memToRemove]; !ok {
+ return &tcpip.ErrBadLocalAddress{}
+ }
+
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ delete(e.multicastMemberships, memToRemove)
+
+ case *tcpip.SocketDetachFilterOption:
+ return nil
+ }
+ return nil
+}
+
+// GetSockOpt returns the socket option.
+func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastInterfaceOption{
+ NIC: e.multicastNICID,
+ InterfaceAddr: e.multicastAddr,
+ }
+ e.mu.Unlock()
+
+ default:
+ return &tcpip.ErrUnknownProtocolOption{}
+ }
+ return nil
+}
+
+// Info returns a copy of the endpoint info.
+func (e *Endpoint) Info() stack.TransportEndpointInfo {
+ e.infoMu.RLock()
+ defer e.infoMu.RUnlock()
+ return e.info
+}
+
+// setInfo sets the endpoint's info.
+//
+// e.mu must be held to synchronize changes to info with the rest of the
+// endpoint.
+//
+// +checklocks:e.mu
+func (e *Endpoint) setInfo(info stack.TransportEndpointInfo) {
+ e.infoMu.Lock()
+ defer e.infoMu.Unlock()
+ e.info = info
+}
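The file above defines the lifecycle a consumer of network.Endpoint follows: Init, optionally Bind, Connect, then AcquireContextForWrite for each send. A minimal sketch of that write path, assuming a stack s already configured with a NIC, address, and route, a remote address remoteAddr, and a payload of type buffer.View (endpoint_test.go below exercises the same sequence end to end):

    // Sketch only; error paths abbreviated.
    var ops tcpip.SocketOptions
    var ep network.Endpoint
    ep.Init(s, ipv4.ProtocolNumber, udp.ProtocolNumber, &ops)
    defer ep.Close()

    if err := ep.Connect(tcpip.FullAddress{Addr: remoteAddr}); err != nil {
        return err
    }

    ctx, err := ep.AcquireContextForWrite(tcpip.WriteOptions{})
    if err != nil {
        return err
    }
    defer ctx.Release()

    info := ctx.PacketInfo()
    pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
        ReserveHeaderBytes: int(info.MaxHeaderLength), // room for headers
        Data:               payload.ToVectorisedView(),
    })
    return ctx.WritePacket(pkt, false /* headerIncluded */)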
diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go
new file mode 100644
index 000000000..68bd1fbf6
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint_state.go
@@ -0,0 +1,58 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package network
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+)
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *Endpoint) Resume(s *stack.Stack) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
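+ // stack is tagged state:"manual", so it must be restored explicitly.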
+ e.stack = s
+
+ for m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ panic(fmt.Sprintf("e.stack.JoinGroup(%d, %d, %s): %s", e.netProto, m.nicID, m.multicastAddr, err))
+ }
+ }
+
+ info := e.Info()
+
+ switch state := e.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound:
+ if len(info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) {
+ if e.stack.CheckLocalAddress(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) == 0 {
+ panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress))
+ }
+ }
+ case transport.DatagramEndpointStateConnected:
+ var err tcpip.Error
+ multicastLoop := e.ops.GetMulticastLoop()
+ e.connectedRoute, err = e.stack.FindRoute(info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop)
+ if err != nil {
+ panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
+ }
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go
new file mode 100644
index 000000000..f263a9ea2
--- /dev/null
+++ b/pkg/tcpip/transport/internal/network/endpoint_test.go
@@ -0,0 +1,318 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package network_test
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+var (
+ ipv4NICAddr = testutil.MustParse4("1.2.3.4")
+ ipv6NICAddr = testutil.MustParse6("a::1")
+ ipv4RemoteAddr = testutil.MustParse4("6.7.8.9")
+ ipv6RemoteAddr = testutil.MustParse6("b::1")
+)
+
+func TestEndpointStateTransitions(t *testing.T) {
+ const nicID = 1
+
+ data := buffer.View([]byte{1, 2, 4, 5})
+ v4Checker := func(t *testing.T, b buffer.View) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(ipv4NICAddr),
+ checker.DstAddr(ipv4RemoteAddr),
+ checker.IPPayload(data),
+ )
+ }
+
+ v6Checker := func(t *testing.T, b buffer.View) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(ipv6NICAddr),
+ checker.DstAddr(ipv6RemoteAddr),
+ checker.IPPayload(data),
+ )
+ }
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ expectedMaxHeaderLength uint16
+ expectedNetProto tcpip.NetworkProtocolNumber
+ expectedLocalAddr tcpip.Address
+ bindAddr tcpip.Address
+ expectedBoundAddr tcpip.Address
+ remoteAddr tcpip.Address
+ expectedRemoteAddr tcpip.Address
+ checker func(*testing.T, buffer.View)
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
+ expectedNetProto: ipv4.ProtocolNumber,
+ expectedLocalAddr: ipv4NICAddr,
+ bindAddr: header.IPv4AllSystems,
+ expectedBoundAddr: header.IPv4AllSystems,
+ remoteAddr: ipv4RemoteAddr,
+ expectedRemoteAddr: ipv4RemoteAddr,
+ checker: v4Checker,
+ },
+ {
+ name: "IPv6",
+ netProto: ipv6.ProtocolNumber,
+ expectedMaxHeaderLength: header.IPv6FixedHeaderSize,
+ expectedNetProto: ipv6.ProtocolNumber,
+ expectedLocalAddr: ipv6NICAddr,
+ bindAddr: header.IPv6AllNodesMulticastAddress,
+ expectedBoundAddr: header.IPv6AllNodesMulticastAddress,
+ remoteAddr: ipv6RemoteAddr,
+ expectedRemoteAddr: ipv6RemoteAddr,
+ checker: v6Checker,
+ },
+ {
+ name: "IPv4-mapped-IPv6",
+ netProto: ipv6.ProtocolNumber,
+ expectedMaxHeaderLength: header.IPv4MaximumHeaderSize,
+ expectedNetProto: ipv4.ProtocolNumber,
+ expectedLocalAddr: ipv4NICAddr,
+ bindAddr: testutil.MustParse6("::ffff:e000:0001"),
+ expectedBoundAddr: header.IPv4AllSystems,
+ remoteAddr: testutil.MustParse6("::ffff:0607:0809"),
+ expectedRemoteAddr: ipv4RemoteAddr,
+ checker: v4Checker,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: &faketime.NullClock{},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ ipv4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ipv4NICAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}: %s", nicID, ipv4ProtocolAddr, err)
+ }
+ ipv6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: ipv6NICAddr.WithPrefix(),
+ }
+
+ if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: ipv4RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
+ {Destination: ipv6RemoteAddr.WithPrefix().Subnet(), NIC: nicID},
+ })
+
+ var ops tcpip.SocketOptions
+ var ep network.Endpoint
+ ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
+ defer ep.Close()
+ if state := ep.State(); state != transport.DatagramEndpointStateInitial {
+ t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateInitial)
+ }
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+ }
+ if state := ep.State(); state != transport.DatagramEndpointStateBound {
+ t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateBound)
+ }
+ if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedBoundAddr}); diff != "" {
+ t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
+ }
+ if addr, connected := ep.GetRemoteAddress(); connected {
+ t.Errorf("got ep.GetRemoteAddress() = (true, %#v), want = (false, _)", addr)
+ }
+
+ connectAddr := tcpip.FullAddress{Addr: test.remoteAddr}
+ if err := ep.Connect(connectAddr); err != nil {
+ t.Fatalf("ep.Connect(%#v): %s", connectAddr, err)
+ }
+ if state := ep.State(); state != transport.DatagramEndpointStateConnected {
+ t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateConnected)
+ }
+ if diff := cmp.Diff(ep.GetLocalAddress(), tcpip.FullAddress{Addr: test.expectedLocalAddr}); diff != "" {
+ t.Errorf("ep.GetLocalAddress() mismatch (-want +got):\n%s", diff)
+ }
+ if addr, connected := ep.GetRemoteAddress(); !connected {
+ t.Errorf("got ep.GetRemoteAddress() = (false, _), want = (true, %#v)", connectAddr)
+ } else if diff := cmp.Diff(addr, tcpip.FullAddress{Addr: test.expectedRemoteAddr}); diff != "" {
+ t.Errorf("remote address mismatch (-want +got):\n%s", diff)
+ }
+
+ ctx, err := ep.AcquireContextForWrite(tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("ep.AcquireContexForWrite({}): %s", err)
+ }
+ defer ctx.Release()
+ info := ctx.PacketInfo()
+ if diff := cmp.Diff(network.WritePacketInfo{
+ NetProto: test.expectedNetProto,
+ LocalAddress: test.expectedLocalAddr,
+ RemoteAddress: test.expectedRemoteAddr,
+ MaxHeaderLength: test.expectedMaxHeaderLength,
+ RequiresTXTransportChecksum: true,
+ }, info); diff != "" {
+ t.Errorf("write packet info mismatch (-want +got):\n%s", diff)
+ }
+ if err := ctx.WritePacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(info.MaxHeaderLength),
+ Data: data.ToVectorisedView(),
+ }), false /* headerIncluded */); err != nil {
+ t.Fatalf("ctx.WritePacket(_, false): %s", err)
+ }
+ if pkt, ok := e.Read(); !ok {
+ t.Fatalf("expected packet to be read from link endpoint")
+ } else {
+ test.checker(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+ }
+
+ ep.Close()
+ if state := ep.State(); state != transport.DatagramEndpointStateClosed {
+ t.Fatalf("got ep.State() = %s, want = %s", state, transport.DatagramEndpointStateClosed)
+ }
+ })
+ }
+}
+
+func TestBindNICID(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ bindAddr tcpip.Address
+ unicast bool
+ }{
+ {
+ name: "IPv4 multicast",
+ netProto: ipv4.ProtocolNumber,
+ bindAddr: header.IPv4AllSystems,
+ unicast: false,
+ },
+ {
+ name: "IPv6 multicast",
+ netProto: ipv6.ProtocolNumber,
+ bindAddr: header.IPv6AllNodesMulticastAddress,
+ unicast: false,
+ },
+ {
+ name: "IPv4 unicast",
+ netProto: ipv4.ProtocolNumber,
+ bindAddr: ipv4NICAddr,
+ unicast: true,
+ },
+ {
+ name: "IPv6 unicast",
+ netProto: ipv6.ProtocolNumber,
+ bindAddr: ipv6NICAddr,
+ unicast: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, testBindNICID := range []tcpip.NICID{0, nicID} {
+ t.Run(fmt.Sprintf("BindNICID=%d", testBindNICID), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: &faketime.NullClock{},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ ipv4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ipv4NICAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err)
+ }
+ ipv6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: ipv6NICAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
+ }
+
+ var ops tcpip.SocketOptions
+ var ep network.Endpoint
+ ep.Init(s, test.netProto, udp.ProtocolNumber, &ops)
+ defer ep.Close()
+ if ep.WasBound() {
+ t.Fatal("got ep.WasBound() = true, want = false")
+ }
+ wantInfo := stack.TransportEndpointInfo{NetProto: test.netProto, TransProto: udp.ProtocolNumber}
+ if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
+ t.Fatalf("ep.Info() mismatch (-want +got):\n%s", diff)
+ }
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr, NIC: testBindNICID}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
+ }
+ if !ep.WasBound() {
+ t.Error("got ep.WasBound() = false, want = true")
+ }
+ wantInfo.ID = stack.TransportEndpointID{LocalAddress: bindAddr.Addr}
+ wantInfo.BindAddr = bindAddr.Addr
+ wantInfo.BindNICID = bindAddr.NIC
+ if test.unicast {
+ wantInfo.RegisterNICID = nicID
+ } else {
+ wantInfo.RegisterNICID = bindAddr.NIC
+ }
+ if diff := cmp.Diff(wantInfo, ep.Info()); diff != "" {
+ t.Errorf("ep.Info() mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 8e7bb6c6e..80eef39e9 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -25,7 +25,6 @@
package packet
import (
- "fmt"
"io"
"time"
@@ -60,52 +59,47 @@ type packet struct {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
waiterQueue *waiter.Queue
cooked bool
-
- // The following fields are used to manage the receive queue and are
- // protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvList packetList
+ ops tcpip.SocketOptions
+ stats tcpip.TransportEndpointStats
+
+ // The following fields are used to manage the receive queue.
+ rcvMu sync.Mutex `state:"nosave"`
+ // +checklocks:rcvMu
+ rcvList packetList
+ // +checklocks:rcvMu
rcvBufSize int
- rcvClosed bool
-
- // The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
- bound bool
+ // +checklocks:rcvMu
+ rcvClosed bool
+ // +checklocks:rcvMu
+ rcvDisabled bool
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ closed bool
+ // +checklocks:mu
+ boundNetProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
boundNIC tcpip.NICID
- // lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
- lastError tcpip.Error
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
- // frozen indicates if the packets should be delivered to the endpoint
- // during restore.
- frozen bool
+ // +checklocks:lastErrorMu
+ lastError tcpip.Error
}
// NewEndpoint returns a new packet endpoint.
func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- },
- cooked: cooked,
- netProto: netProto,
- waiterQueue: waiterQueue,
+ stack: s,
+ cooked: cooked,
+ boundNetProto: netProto,
+ waiterQueue: waiterQueue,
}
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
@@ -141,7 +135,7 @@ func (ep *endpoint) Close() {
return
}
- ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
ep.rcvMu.Lock()
defer ep.rcvMu.Unlock()
@@ -154,7 +148,6 @@ func (ep *endpoint) Close() {
}
ep.closed = true
- ep.bound = false
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
@@ -189,7 +182,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
Total: packet.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: packet.receivedAt.UnixNano(),
+ Timestamp: packet.receivedAt,
},
}
if opts.NeedRemoteAddr {
@@ -207,8 +200,52 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
return res, nil
}
-func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) {
- return 0, &tcpip.ErrInvalidOptionValue{}
+func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ if !ep.stack.PacketEndpointWriteSupported() {
+ return 0, &tcpip.ErrNotSupported{}
+ }
+
+ ep.mu.Lock()
+ closed := ep.closed
+ nicID := ep.boundNIC
+ proto := ep.boundNetProto
+ ep.mu.Unlock()
+ if closed {
+ return 0, &tcpip.ErrClosedForSend{}
+ }
+
+ var remote tcpip.LinkAddress
+ if to := opts.To; to != nil {
+ remote = tcpip.LinkAddress(to.Addr)
+
+ if n := to.NIC; n != 0 {
+ nicID = n
+ }
+
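+ // As with Linux packet sockets, the Port field carries the network
+ // protocol (EtherType) rather than a transport port.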
+ if p := to.Port; p != 0 {
+ proto = tcpip.NetworkProtocolNumber(p)
+ }
+ }
+
+ if nicID == 0 {
+ return 0, &tcpip.ErrInvalidOptionValue{}
+ }
+
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
+ payloadBytes := make(buffer.View, p.Len())
+ if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
+ }
+
+ if err := func() tcpip.Error {
+ if ep.cooked {
+ return ep.stack.WritePacketToRemote(nicID, remote, proto, payloadBytes.ToVectorisedView())
+ }
+ return ep.stack.WriteRawPacket(nicID, proto, payloadBytes.ToVectorisedView())
+ }(); err != nil {
+ return 0, err
+ }
+ return int64(len(payloadBytes)), nil
}
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
@@ -253,29 +290,42 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound && ep.boundNIC == addr.NIC {
- // If the NIC being bound is the same then just return success.
+ netProto := tcpip.NetworkProtocolNumber(addr.Port)
+ if netProto == 0 {
+ // Do not allow unbinding the network protocol.
+ netProto = ep.boundNetProto
+ }
+
+ if ep.boundNIC == addr.NIC && ep.boundNetProto == netProto {
+ // Already bound to the requested NIC and network protocol.
return nil
}
- // Unregister endpoint with all the nics.
- ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
- ep.bound = false
+ // TODO(https://gvisor.dev/issue/6618): Unregister after registering the new
+ // binding.
+ ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
+ ep.boundNIC = 0
+ ep.boundNetProto = 0
// Bind endpoint to receive packets from specific interface.
- if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil {
return err
}
- ep.bound = true
ep.boundNIC = addr.NIC
-
+ ep.boundNetProto = netProto
return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
- return tcpip.FullAddress{}, &tcpip.ErrNotSupported{}
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ NIC: ep.boundNIC,
+ Port: uint16(ep.boundNetProto),
+ }, nil
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
@@ -359,7 +409,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
}
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, _ tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
ep.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -371,7 +421,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
}
rcvBufSize := ep.ops.GetReceiveBufferSize()
- if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) {
+ if ep.rcvDisabled || ep.rcvBufSize >= int(rcvBufSize) {
ep.rcvMu.Unlock()
ep.stack.Stats().DroppedPackets.Increment()
ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -380,76 +430,39 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
wasEmpty := ep.rcvBufSize == 0
- // Push new packet into receive list and increment the buffer size.
- var packet packet
+ rcvdPkt := packet{
+ packetInfo: tcpip.LinkPacketInfo{
+ Protocol: netProto,
+ PktType: pkt.PktType,
+ },
+ senderAddr: tcpip.FullAddress{
+ NIC: nicID,
+ },
+ receivedAt: ep.stack.Clock().Now(),
+ }
+
if !pkt.LinkHeader().View().IsEmpty() {
- // Get info directly from the ethernet header.
hdr := header.Ethernet(pkt.LinkHeader().View())
- packet.senderAddr = tcpip.FullAddress{
- NIC: nicID,
- Addr: tcpip.Address(hdr.SourceAddress()),
- }
- packet.packetInfo.Protocol = netProto
- packet.packetInfo.PktType = pkt.PktType
- } else {
- // Guess the would-be ethernet header.
- packet.senderAddr = tcpip.FullAddress{
- NIC: nicID,
- Addr: tcpip.Address(localAddr),
- }
- packet.packetInfo.Protocol = netProto
- packet.packetInfo.PktType = pkt.PktType
+ rcvdPkt.senderAddr.Addr = tcpip.Address(hdr.SourceAddress())
}
if ep.cooked {
- // Cooked packets can simply be queued.
- switch pkt.PktType {
- case tcpip.PacketHost:
- packet.data = pkt.Data().ExtractVV()
- case tcpip.PacketOutgoing:
- // Strip Link Header.
- var combinedVV buffer.VectorisedView
- if v := pkt.NetworkHeader().View(); !v.IsEmpty() {
- combinedVV.AppendView(v)
- }
- if v := pkt.TransportHeader().View(); !v.IsEmpty() {
- combinedVV.AppendView(v)
- }
- combinedVV.Append(pkt.Data().ExtractVV())
- packet.data = combinedVV
- default:
- panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ // Cooked packet endpoints don't include the link-headers in received
+ // packets.
+ if v := pkt.NetworkHeader().View(); !v.IsEmpty() {
+ rcvdPkt.data.AppendView(v)
}
- } else {
- // Raw packets need their ethernet headers prepended before
- // queueing.
- var linkHeader buffer.View
- if pkt.PktType != tcpip.PacketOutgoing {
- if pkt.LinkHeader().View().IsEmpty() {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
- }
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- linkHeader = buffer.View(fakeHeader)
- } else {
- linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...)
- }
- combinedVV := linkHeader.ToVectorisedView()
- combinedVV.Append(pkt.Data().ExtractVV())
- packet.data = combinedVV
- } else {
- packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ if v := pkt.TransportHeader().View(); !v.IsEmpty() {
+ rcvdPkt.data.AppendView(v)
}
+ rcvdPkt.data.Append(pkt.Data().ExtractVV())
+ } else {
+ // Raw packet endpoints include link-headers in received packets.
+ rcvdPkt.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views())
}
- packet.receivedAt = ep.stack.Clock().Now()
- ep.rcvList.PushBack(&packet)
- ep.rcvBufSize += packet.data.Size()
+ ep.rcvList.PushBack(&rcvdPkt)
+ ep.rcvBufSize += rcvdPkt.data.Size()
ep.rcvMu.Unlock()
ep.stats.PacketsReceived.Increment()
@@ -467,10 +480,8 @@ func (*endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (ep *endpoint) Info() tcpip.EndpointInfo {
ep.mu.RLock()
- // Make a copy of the endpoint info.
- ret := ep.TransportEndpointInfo
- ep.mu.RUnlock()
- return &ret
+ defer ep.mu.RUnlock()
+ return &stack.TransportEndpointInfo{NetProto: ep.boundNetProto}
}
// Stats returns a pointer to the endpoint stats.
@@ -485,18 +496,3 @@ func (*endpoint) SetOwner(tcpip.PacketOwner) {}
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
return &ep.ops
}
-
-// freeze prevents any more packets from being delivered to the endpoint.
-func (ep *endpoint) freeze() {
- ep.mu.Lock()
- ep.frozen = true
- ep.mu.Unlock()
-}
-
-// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
-// new packets to be delivered again.
-func (ep *endpoint) thaw() {
- ep.mu.Lock()
- ep.frozen = false
- ep.mu.Unlock()
-}
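With Bind and Write reworked above, packet endpoints take AF_PACKET-style addresses: in a tcpip.FullAddress, Addr carries the remote link-layer address and Port carries the network protocol (EtherType). A hedged sketch of the new write path, assuming ep came from NewEndpoint, the stack reports PacketEndpointWriteSupported(), and frame, dstMAC, and nicID are supplied by the caller:

    // Sketch only; frame, dstMAC, and nicID are assumed names.
    to := tcpip.FullAddress{
        NIC:  nicID,                             // interface to send through
        Addr: tcpip.Address(dstMAC),             // remote link-layer address
        Port: uint16(header.IPv4ProtocolNumber), // EtherType rides in Port
    }
    r := bytes.NewReader(frame) // satisfies tcpip.Payloader (Read + Len)
    if _, err := ep.Write(r, tcpip.WriteOptions{To: &to}); err != nil {
        return err
    }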
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index e729921db..88cd80ad3 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -34,33 +35,34 @@ func (p *packet) loadReceivedAt(nsec int64) {
// saveData saves packet.data field.
func (p *packet) saveData() buffer.VectorisedView {
- // We cannot save p.data directly as p.data.views may alias to p.views,
- // which is not allowed by state framework (in-struct pointer).
return p.data.Clone(nil)
}
// loadData loads packet.data field.
func (p *packet) loadData(data buffer.VectorisedView) {
- // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
- // here because data.views is not guaranteed to be loaded by now. Plus,
- // data.views will be allocated anyway so there really is little point
- // of utilizing p.views for data.views.
p.data = data
}
// beforeSave is invoked by stateify.
func (ep *endpoint) beforeSave() {
- ep.freeze()
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+ ep.rcvDisabled = true
}
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
- ep.thaw()
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
ep.stack = stack.StackFromEnv
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
- if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
- panic(err)
+ if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil {
+ panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err))
}
+
+ ep.rcvMu.Lock()
+ ep.rcvDisabled = false
+ ep.rcvMu.Unlock()
}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 2eab09088..b7e97e218 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -33,6 +33,8 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/packet",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b3d8951ff..181b478d0 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,6 +26,7 @@
package raw
import (
+ "fmt"
"io"
"time"
@@ -34,6 +35,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -57,15 +60,19 @@ type rawPacket struct {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
+ transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
associated bool
+ net network.Endpoint
+ stats tcpip.TransportEndpointStats
+ ops tcpip.SocketOptions
+
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
@@ -74,20 +81,7 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- closed bool
- connected bool
- bound bool
- // route is the route to a remote network endpoint. It is set via
- // Connect(), and is valid only when conneted is true.
- route *stack.Route `state:"manual"`
- stats tcpip.TransportEndpointStats `state:"nosave"`
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
+ mu sync.RWMutex `state:"nosave"`
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
@@ -99,16 +93,9 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) {
- if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
- return nil, &tcpip.ErrUnknownProtocol{}
- }
-
e := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: transProto,
- },
+ stack: s,
+ transProto: transProto,
waiterQueue: waiterQueue,
associated: associated,
}
@@ -116,6 +103,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
e.ops.SetHeaderIncluded(!associated)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ e.net.Init(s, netProto, transProto, &e.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -137,7 +125,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
return e, nil
}
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
return nil, err
}
@@ -154,11 +142,17 @@ func (e *endpoint) Close() {
e.mu.Lock()
defer e.mu.Unlock()
- if e.closed || !e.associated {
+ if e.net.State() == transport.DatagramEndpointStateClosed {
return
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
+ e.net.Close()
+
+ if !e.associated {
+ return
+ }
+
+ e.stack.UnregisterRawTransportEndpoint(e.net.NetProto(), e.transProto, e)
e.rcvMu.Lock()
defer e.rcvMu.Unlock()
@@ -170,15 +164,6 @@ func (e *endpoint) Close() {
e.rcvList.Remove(e.rcvList.Front())
}
- e.connected = false
-
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
-
- e.closed = true
-
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
@@ -186,9 +171,7 @@ func (e *endpoint) Close() {
func (*endpoint) ModerateRecvBuf(int) {}
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.mu.Lock()
- defer e.mu.Unlock()
- e.owner = owner
+ e.net.SetOwner(owner)
}
// Read implements tcpip.Endpoint.Read.
@@ -219,7 +202,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: pkt.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: pkt.receivedAt.UnixNano(),
+ Timestamp: pkt.receivedAt,
},
}
if opts.NeedRemoteAddr {
@@ -236,14 +219,15 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// Write implements tcpip.Endpoint.Write.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ netProto := e.net.NetProto()
// We can create, but not write to, unassociated IPv6 endpoints.
- if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
+ if !e.associated && netProto == header.IPv6ProtocolNumber {
return 0, &tcpip.ErrInvalidOptionValue{}
}
if opts.To != nil {
// Raw sockets do not support sending to an IPv4 address on an IPv6 endpoint.
- if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
+ if netProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
return 0, &tcpip.ErrInvalidOptionValue{}
}
}
@@ -269,79 +253,25 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
+ ctx, err := e.net.AcquireContextForWrite(opts)
+ if err != nil {
+ return 0, err
}
- payloadBytes, route, owner, err := func() ([]byte, *stack.Route, tcpip.PacketOwner, tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
-
- if e.closed {
- return nil, nil, nil, &tcpip.ErrInvalidEndpointState{}
- }
-
- payloadBytes := make([]byte, p.Len())
- if _, err := io.ReadFull(p, payloadBytes); err != nil {
- return nil, nil, nil, &tcpip.ErrBadBuffer{}
- }
-
- // Did the user caller provide a destination? If not, use the connected
- // destination.
- if opts.To == nil {
- // If the user doesn't specify a destination, they should have
- // connected to another address.
- if !e.connected {
- return nil, nil, nil, &tcpip.ErrDestinationRequired{}
- }
- e.route.Acquire()
-
- return payloadBytes, e.route, e.owner, nil
- }
-
- // The caller provided a destination. Reject destination address if it
- // goes through a different NIC than the endpoint was bound to.
- nic := opts.To.NIC
- if e.bound && nic != 0 && nic != e.BindNICID {
- return nil, nil, nil, &tcpip.ErrNoRoute{}
- }
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
+ payloadBytes := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
+ }
- // Find the route to the destination. If BindAddress is 0,
- // FindRoute will choose an appropriate source address.
- route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
- if err != nil {
- return nil, nil, nil, err
- }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(ctx.PacketInfo().MaxHeaderLength),
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ })
- return payloadBytes, route, e.owner, nil
- }()
- if err != nil {
+ if err := ctx.WritePacket(pkt, e.ops.GetHeaderIncluded()); err != nil {
return 0, err
}
- defer route.Release()
-
- if e.ops.GetHeaderIncluded() {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- })
- if err := route.WriteHeaderIncludedPacket(pkt); err != nil {
- return 0, err
- }
- } else {
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(route.MaxHeaderLength()),
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- })
- pkt.Owner = owner
- if err := route.WritePacket(stack.NetworkHeaderParams{
- Protocol: e.TransProto,
- TTL: route.DefaultTTL(),
- TOS: stack.DefaultTOS,
- }, pkt); err != nil {
- return 0, err
- }
- }
return int64(len(payloadBytes)), nil
}
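Illustrative caller of the rewritten write path above (editorial sketch, not part of the change; the helper is hypothetical and assumes the tcpip.Endpoint signatures as of this diff):

import (
	"bytes"

	"gvisor.dev/gvisor/pkg/tcpip"
)

// sendRaw writes one datagram through a raw endpoint. *bytes.Reader
// satisfies tcpip.Payloader (io.Reader plus Len), so write() above can
// read the payload directly from it into its own buffer.
func sendRaw(ep tcpip.Endpoint, payload []byte, dst tcpip.FullAddress) (int64, tcpip.Error) {
	return ep.Write(bytes.NewReader(payload), tcpip.WriteOptions{To: &dst})
}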
@@ -353,66 +283,29 @@ func (*endpoint) Disconnect() tcpip.Error {
// Connect implements tcpip.Endpoint.Connect.
func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+ netProto := e.net.NetProto()
+
// Raw sockets do not support connecting to an IPv4 address on an IPv6 endpoint.
- if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
+ if netProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
return &tcpip.ErrAddressFamilyNotSupported{}
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if e.closed {
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- nic := addr.NIC
- if e.bound {
- if e.BindNICID == 0 {
- // If we're bound, but not to a specific NIC, the NIC
- // in addr will be used. Nothing to do here.
- } else if addr.NIC == 0 {
- // If we're bound to a specific NIC, but addr doesn't
- // specify a NIC, use the bound NIC.
- nic = e.BindNICID
- } else if addr.NIC != e.BindNICID {
- // We're bound and addr specifies a NIC. They must be
- // the same.
- return &tcpip.ErrInvalidEndpointState{}
- }
- }
-
- // Find a route to the destination.
- route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false)
- if err != nil {
- return err
- }
-
- if e.associated {
- // Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
- route.Release()
- return err
+ return e.net.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error {
+ if e.associated {
+ // Re-register the endpoint with the appropriate NIC.
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
+ return err
+ }
+ e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e)
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
- e.RegisterNICID = nic
- }
-
- if e.route != nil {
- // If the endpoint was previously connected then release any previous route.
- e.route.Release()
- }
- e.route = route
- e.connected = true
- return nil
+ return nil
+ })
}
// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if !e.connected {
+ if e.net.State() != transport.DatagramEndpointStateConnected {
return &tcpip.ErrNotConnected{}
}
return nil
@@ -430,46 +323,26 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
// Bind implements tcpip.Endpoint.Bind.
func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // If a local address was specified, verify that it's valid.
- if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 {
- return &tcpip.ErrBadLocalAddress{}
- }
+ return e.net.BindAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, _ tcpip.Address) tcpip.Error {
+ if !e.associated {
+ return nil
+ }
- if e.associated {
// Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
return err
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
- e.RegisterNICID = addr.NIC
- e.BindNICID = addr.NIC
- }
-
- e.BindAddr = addr.Addr
- e.bound = true
-
- return nil
+ e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e)
+ return nil
+ })
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
-
- addr := e.BindAddr
- if e.connected {
- addr = e.route.LocalAddress()
- }
-
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: addr,
- // Linux returns the protocol in the port field.
- Port: uint16(e.TransProto),
- }, nil
+ a := e.net.GetLocalAddress()
+ // Linux returns the protocol in the port field.
+ a.Port = uint16(e.transProto)
+ return a, nil
}
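Concrete reading of the comment above (the numeric value is our assumption from the Linux ABI): an endpoint created for IPPROTO_ICMP reports Port == 1 from GetLocalAddress, matching how Linux getsockname(2) exposes the protocol number in the port field of a raw socket's address.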
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
@@ -502,17 +375,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
return nil
default:
- return &tcpip.ErrUnknownProtocolOption{}
+ return e.net.SetSockOpt(opt)
}
}
-func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+ return e.net.SetSockOptInt(opt, v)
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return e.net.GetSockOpt(opt)
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -529,100 +402,108 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
return v, nil
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
- e.mu.RLock()
- e.rcvMu.Lock()
+ notifyReadableEvents := func() bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ e.rcvMu.Lock()
+ defer e.rcvMu.Unlock()
+
+ // Drop the packet if our buffer is currently full or if this is an unassociated
+ // endpoint (i.e. an endpoint created w/ IPPROTO_RAW). Such endpoints are send-only.
+ // See: https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return false
+ }
- // Drop the packet if our buffer is currently full or if this is an unassociated
- // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only
- // See: https://man7.org/linux/man-pages/man7/raw.7.html
- //
- // An IPPROTO_RAW socket is send only. If you really want to receive
- // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
- // Note that packet sockets don't reassemble IP fragments, unlike raw
- // sockets.
- if e.rcvClosed || !e.associated {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stack.Stats().DroppedPackets.Increment()
- e.stats.ReceiveErrors.ClosedReceiver.Increment()
- return
- }
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return false
+ }
- rcvBufSize := e.ops.GetReceiveBufferSize()
- if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stack.Stats().DroppedPackets.Increment()
- e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
- return
- }
+ srcAddr := pkt.Network().SourceAddress()
+ info := e.net.Info()
- remoteAddr := pkt.Network().SourceAddress()
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
+ // If connected, only accept packets from the remote address we
+ // connected to.
+ if info.ID.RemoteAddress != srcAddr {
+ return false
+ }
- if e.bound {
- // If bound to a NIC, only accept data for that NIC.
- if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
- }
- // If bound to an address, only accept data for that address.
- if e.BindAddr != "" && e.BindAddr != remoteAddr {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
+ // Connected sockets may also have been bound to a specific
+ // address/NIC.
+ fallthrough
+ case transport.DatagramEndpointStateBound:
+ // If bound to a NIC, only accept data for that NIC.
+ if info.BindNICID != 0 && info.BindNICID != pkt.NICID {
+ return false
+ }
+
+ // If bound to an address, only accept data for that address.
+ if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() {
+ return false
+ }
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- }
- // If connected, only accept packets from the remote address we
- // connected to.
- if e.connected && e.route.RemoteAddress() != remoteAddr {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
- }
+ wasEmpty := e.rcvBufSize == 0
- wasEmpty := e.rcvBufSize == 0
+ // Push new packet into receive list and increment the buffer size.
+ packet := &rawPacket{
+ senderAddr: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: srcAddr,
+ },
+ }
- // Push new packet into receive list and increment the buffer size.
- packet := &rawPacket{
- senderAddr: tcpip.FullAddress{
- NIC: pkt.NICID,
- Addr: remoteAddr,
- },
- }
+ // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+ // We copy headers' underlying bytes because pkt.*Header may point to
+ // the middle of a slice, and another struct may point to the "outer"
+ // slice. Save/restore doesn't support overlapping slices and will fail.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports
+ // overlapping slices.
+ var combinedVV buffer.VectorisedView
+ if info.NetProto == header.IPv4ProtocolNumber {
+ networkHeader, transportHeader := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ headers := make(buffer.View, 0, len(networkHeader)+len(transportHeader))
+ headers = append(headers, networkHeader...)
+ headers = append(headers, transportHeader...)
+ combinedVV = headers.ToVectorisedView()
+ } else {
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
+ }
+ combinedVV.Append(pkt.Data().ExtractVV())
+ packet.data = combinedVV
+ packet.receivedAt = e.stack.Clock().Now()
- // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
- // We copy headers' underlying bytes because pkt.*Header may point to
- // the middle of a slice, and another struct may point to the "outer"
- // slice. Save/restore doesn't support overlapping slices and will fail.
- var combinedVV buffer.VectorisedView
- if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
- headers := make(buffer.View, 0, len(network)+len(transport))
- headers = append(headers, network...)
- headers = append(headers, transport...)
- combinedVV = headers.ToVectorisedView()
- } else {
- combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
- }
- combinedVV.Append(pkt.Data().ExtractVV())
- packet.data = combinedVV
- packet.receivedAt = e.stack.Clock().Now()
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += packet.data.Size()
+ e.stats.PacketsReceived.Increment()
- e.rcvList.PushBack(packet)
- e.rcvBufSize += packet.data.Size()
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stats.PacketsReceived.Increment()
- // Notify waiters that there's data to be read.
- if wasEmpty {
+ // Notify waiters that there is data to be read now.
+ return wasEmpty
+ }()
+
+ if notifyReadableEvents {
e.waiterQueue.Notify(waiter.ReadableEvents)
}
}
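Minimal sketch of the slice aliasing that the header copy in HandlePacket above avoids (illustrative sizes, not taken from the change):

	buf := make([]byte, 40)
	netHdr := buf[:20]  // like pkt.NetworkHeader().View()
	tptHdr := buf[20:]  // like pkt.TransportHeader().View()
	// netHdr and tptHdr alias buf's backing array; gVisor's save/restore
	// rejects overlapping slices, so the code above copies header bytes
	// into fresh storage before queueing the packet.
	_, _ = netHdr, tptHdr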
@@ -634,10 +515,7 @@ func (e *endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
- e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
+ ret := e.net.Info()
return &ret
}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 39669b445..e74713064 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -15,6 +15,7 @@
package raw
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -60,35 +61,16 @@ func (e *endpoint) beforeSave() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.net.Resume(s)
+
e.thaw()
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- // If the endpoint is connected, re-connect.
- if e.connected {
- var err tcpip.Error
- // TODO(gvisor.dev/issue/4906): Properly restore the route with the right
- // remote address. We used to pass e.remote.RemoteAddress which was
- // effectively the empty address but since moving e.route to hold a pointer
- // to a route instead of the route by value, we pass the empty address
- // directly. Obviously this was always wrong since we should provide the
- // remote address we were connected to, to properly restore the route.
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false)
- if err != nil {
- panic(err)
- }
- }
-
- // If the endpoint is bound, re-bind.
- if e.bound {
- if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
if e.associated {
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
- panic(err)
+ netProto := e.net.NetProto()
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
+ panic(fmt.Sprintf("e.stack.RegisterRawTransportEndpoint(%d, %d, _): %s", netProto, e.transProto, err))
}
}
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 8436d2cf0..20958d882 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -68,6 +68,7 @@ go_library(
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
+ "//pkg/tcpip/internal/tcp",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
@@ -79,9 +80,10 @@ go_library(
go_test(
name = "tcp_x_test",
- size = "medium",
+ size = "large",
srcs = [
"dual_stack_test.go",
+ "rcv_test.go",
"sack_scoreboard_test.go",
"tcp_noracedetector_test.go",
"tcp_rack_test.go",
@@ -96,6 +98,7 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
@@ -112,16 +115,6 @@ go_test(
)
go_test(
- name = "rcv_test",
- size = "small",
- srcs = ["rcv_test.go"],
- deps = [
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- ],
-)
-
-go_test(
name = "tcp_test",
size = "small",
srcs = [
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index aa413ad05..caf14b0dc 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -15,12 +15,12 @@
package tcp
import (
+ "container/list"
"crypto/sha1"
"encoding/binary"
"fmt"
"hash"
"io"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
@@ -72,7 +72,8 @@ func encodeMSS(mss uint16) uint32 {
// and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects.
type listenContext struct {
- stack *stack.Stack
+ stack *stack.Stack
+ protocol *protocol
// rcvWnd is the receive window that is sent by this listening context
// in the initial SYN-ACK.
@@ -99,18 +100,6 @@ type listenContext struct {
// netProto indicates the network protocol(IPv4/v6) for the listening
// endpoint.
netProto tcpip.NetworkProtocolNumber
-
- // pendingMu protects pendingEndpoints. This should only be accessed
- // by the listening endpoint's worker goroutine.
- //
- // Lock Ordering: listenEP.workerMu -> pendingMu
- pendingMu sync.Mutex
- // pending is used to wait for all pendingEndpoints to finish when
- // a socket is closed.
- pending sync.WaitGroup
- // pendingEndpoints is a map of all endpoints for which a handshake is
- // in progress.
- pendingEndpoints map[stack.TransportEndpointID]*endpoint
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
@@ -119,15 +108,15 @@ func timeStamp(clock tcpip.Clock) uint32 {
}
// newListenContext creates a new listen context.
-func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
- stack: stk,
- rcvWnd: rcvWnd,
- hasher: sha1.New(),
- v6Only: v6Only,
- netProto: netProto,
- listenEP: listenEP,
- pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
+ stack: stk,
+ protocol: protocol,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6Only: v6Only,
+ netProto: netProto,
+ listenEP: listenEP,
}
for i := range l.nonce {
@@ -191,17 +180,9 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
}
-func (l *listenContext) useSynCookies() bool {
- var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
- if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
- panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
- }
- return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull())
-}
-
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -213,7 +194,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header
return nil, err
}
- n := newEndpoint(l.stack, netProto, queue)
+ n := newEndpoint(l.stack, l.protocol, netProto, queue)
n.ops.SetV6Only(l.v6Only)
n.TransportEndpointInfo.ID = s.id
n.boundNICID = s.nicID
@@ -244,10 +225,10 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header
// On success, a handshake h is returned with h.ep.mu held.
//
// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
-func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) {
+func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
- isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed())
+ isn := generateSecureISN(s.id, l.stack.Clock(), l.protocol.seqnumSecret)
ep, err := l.createConnectingEndpoint(s, opts, queue)
if err != nil {
return nil, err
@@ -271,18 +252,15 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
return nil, &tcpip.ErrConnectionAborted{}
}
- l.addPendingEndpoint(ep)
// Propagate any inheritable options from the listening endpoint
// to the newly created endpoint.
- l.listenEP.propagateInheritableOptionsLocked(ep)
+ l.listenEP.propagateInheritableOptionsLocked(ep) // +checklocksforce
if !ep.reserveTupleLocked() {
ep.mu.Unlock()
ep.Close()
- l.removePendingEndpoint(ep)
-
return nil, &tcpip.ErrConnectionAborted{}
}
@@ -301,10 +279,6 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
ep.mu.Unlock()
ep.Close()
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- }
-
ep.drainClosingSegmentQueue()
return nil, err
@@ -323,7 +297,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
// established endpoint is returned with e.mu held.
//
// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
-func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) {
+func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) {
h, err := l.startHandshake(s, opts, queue, owner)
if err != nil {
return nil, err
@@ -342,39 +316,12 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions,
return ep, nil
}
-func (l *listenContext) addPendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- l.pendingEndpoints[n.TransportEndpointInfo.ID] = n
- l.pending.Add(1)
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) removePendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- delete(l.pendingEndpoints, n.TransportEndpointInfo.ID)
- l.pending.Done()
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) closeAllPendingEndpoints() {
- l.pendingMu.Lock()
- for _, n := range l.pendingEndpoints {
- n.notifyProtocolGoroutine(notifyClose)
- }
- l.pendingMu.Unlock()
- l.pending.Wait()
-}
-
-// Precondition: h.ep.mu must be held.
// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
e.mu.Unlock()
e.Close()
e.notifyAborted()
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.drainClosingSegmentQueue()
e.h = nil
}
@@ -382,12 +329,9 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) {
// cleanupCompletedHandshake transfers any state from the completed handshake to
// the new endpoint.
//
-// Precondition: h.ep.mu must be held.
+// +checklocks:h.ep.mu
func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e := h.ep
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.isConnectNotified = true
// Update the receive window scaling. We can't do it before the
@@ -399,47 +343,11 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e.h = nil
}
-// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// listener has transitioned out of the listen state (accepted is the zero
-// value), the new endpoint is reset instead.
-func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
- e.mu.Lock()
- e.pendingAccepted.Add(1)
- e.mu.Unlock()
- defer e.pendingAccepted.Done()
-
- // Drop the lock before notifying to avoid deadlock in user-specified
- // callbacks.
- delivered := func() bool {
- e.acceptMu.Lock()
- defer e.acceptMu.Unlock()
- for {
- if e.accepted == (accepted{}) {
- return false
- }
- if e.accepted.endpoints.Len() == e.accepted.cap {
- e.acceptCond.Wait()
- continue
- }
-
- e.accepted.endpoints.PushBack(n)
- if !withSynCookie {
- atomic.AddInt32(&e.synRcvdCount, -1)
- }
- return true
- }
- }()
- if delivered {
- e.waiterQueue.Notify(waiter.ReadableEvents)
- } else {
- n.notifyProtocolGoroutine(notifyReset)
- }
-}
-
// propagateInheritableOptionsLocked propagates any options set on the listening
// endpoint to the newly created endpoint.
//
-// Precondition: e.mu and n.mu must be held.
+// +checklocks:e.mu
+// +checklocks:n.mu
func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.userTimeout = e.userTimeout
n.portFlags = e.portFlags
@@ -450,9 +358,9 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
// reserveTupleLocked reserves an accepted endpoint's tuple.
//
-// Preconditions:
-// * propagateInheritableOptionsLocked has been called.
-// * e.mu is held.
+// Precondition: e.propagateInheritableOptionsLocked has been called.
+//
+// +checklocks:e.mu
func (e *endpoint) reserveTupleLocked() bool {
dest := tcpip.FullAddress{
Addr: e.TransportEndpointInfo.ID.RemoteAddress,
@@ -487,70 +395,36 @@ func (e *endpoint) notifyAborted() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
-// handleSynSegment is called in its own goroutine once the listening endpoint
-// receives a SYN segment. It is responsible for completing the handshake and
-// queueing the new endpoint for acceptance.
-//
-// A limited number of these goroutines are allowed before TCP starts using SYN
-// cookies to accept connections.
-//
-// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) tcpip.Error {
- defer s.decRef()
-
- h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
- if err != nil {
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
- atomic.AddInt32(&e.synRcvdCount, -1)
- return err
- }
+func (e *endpoint) acceptQueueIsFull() bool {
+ e.acceptMu.Lock()
+ full := e.acceptQueue.isFull()
+ e.acceptMu.Unlock()
+ return full
+}
- go func() {
- // Note that startHandshake returns a locked endpoint. The
- // force call here just makes it so.
- if err := h.complete(); err != nil { // +checklocksforce
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
- ctx.cleanupFailedHandshake(h)
- atomic.AddInt32(&e.synRcvdCount, -1)
- return
- }
- ctx.cleanupCompletedHandshake(h)
- h.ep.startAcceptedLoop()
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(h.ep, false /*withSynCookie*/)
- }()
+// +stateify savable
+type acceptQueue struct {
+ // NB: this could be an endpointList, but ilist only permits endpoints to
+ // belong to one list at a time, and endpoints are already stored in the
+ // dispatcher's list.
+ endpoints list.List `state:".([]*endpoint)"`
- return nil
-}
+ // pendingEndpoints is a set of all endpoints for which a handshake is
+ // in progress.
+ pendingEndpoints map[*endpoint]struct{}
-func (e *endpoint) synRcvdBacklogFull() bool {
- e.acceptMu.Lock()
- acceptedCap := e.accepted.cap
- e.acceptMu.Unlock()
- // The capacity of the accepted queue would always be one greater than the
- // listen backlog. But, the SYNRCVD connections count is always checked
- // against the listen backlog value for Linux parity reason.
- // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
- //
- // We maintain an equality check here as the synRcvdCount is incremented
- // and compared only from a single listener context and the capacity of
- // the accepted queue can only increase by a new listen call.
- return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
+ // capacity is the maximum number of endpoints that can be in endpoints.
+ capacity int
}
-func (e *endpoint) acceptQueueIsFull() bool {
- e.acceptMu.Lock()
- full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap
- e.acceptMu.Unlock()
- return full
+func (a *acceptQueue) isFull() bool {
+ return a.endpoints.Len() == a.capacity
}
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
//
-// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
+// +checklocks:e.mu
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error {
e.rcvQueueInfo.rcvQueueMu.Lock()
rcvClosed := e.rcvQueueInfo.RcvClosed
@@ -578,11 +452,95 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
opts := parseSynSegmentOptions(s)
- if !ctx.useSynCookies() {
- s.incRef()
- atomic.AddInt32(&e.synRcvdCount, 1)
- return e.handleSynSegment(ctx, s, &opts)
+
+ useSynCookies, err := func() (bool, tcpip.Error) {
+ var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
+ if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
+ panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
+ }
+ if alwaysUseSynCookies {
+ return true, nil
+ }
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+
+ // The capacity of the accepted queue is always one greater than the
+ // listen backlog. But the SYNRCVD connections count is always checked
+ // against the listen backlog value for Linux parity reasons.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
+ if len(e.acceptQueue.pendingEndpoints) == e.acceptQueue.capacity-1 {
+ return true, nil
+ }
+
+ h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
+ if err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ return false, err
+ }
+
+ e.acceptQueue.pendingEndpoints[h.ep] = struct{}{}
+ e.pendingAccepted.Add(1)
+
+ go func() {
+ defer func() {
+ e.pendingAccepted.Done()
+
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ delete(e.acceptQueue.pendingEndpoints, h.ep)
+ }()
+
+ // Note that startHandshake returns a locked endpoint. The force call
+ // here just makes it so.
+ if err := h.complete(); err != nil { // +checklocksforce
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ ctx.cleanupFailedHandshake(h)
+ return
+ }
+ ctx.cleanupCompletedHandshake(h)
+ h.ep.startAcceptedLoop()
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
+ // Deliver the endpoint to the accept queue.
+ //
+ // Drop the lock before notifying to avoid deadlock in user-specified
+ // callbacks.
+ delivered := func() bool {
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ for {
+ // The listener is transitioning out of the Listen state; bail.
+ if e.acceptQueue.capacity == 0 {
+ return false
+ }
+ if e.acceptQueue.isFull() {
+ e.acceptCond.Wait()
+ continue
+ }
+
+ e.acceptQueue.endpoints.PushBack(h.ep)
+ return true
+ }
+ }()
+
+ if delivered {
+ e.waiterQueue.Notify(waiter.ReadableEvents)
+ } else {
+ h.ep.notifyProtocolGoroutine(notifyReset)
+ }
+ }()
+
+ return false, nil
+ }()
+ if err != nil {
+ return err
}
+ if !useSynCookies {
+ return nil
+ }
+
route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
if err != nil {
return err
@@ -600,10 +558,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())),
TSEcr: opts.TSVal,
MSS: calculateAdvertisedMSS(e.userMSS, route),
}
+ if opts.TS {
+ offset := e.protocol.tsOffset(s.dstAddr, s.srcAddr)
+ now := e.stack.Clock().NowMonotonic()
+ synOpts.TSVal = offset.TSVal(now)
+ }
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
fields := tcpFields{
id: s.id,
@@ -621,18 +583,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
case s.flags.Contains(header.TCPFlagAck):
- if e.acceptQueueIsFull() {
- // Silently drop the ack as the application can't accept
- // the connection at this point. The ack will be
- // retransmitted by the sender anyway and we can
- // complete the connection at the time of retransmit if
- // the backlog has space.
- e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
- e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
- e.stack.Stats().DroppedPackets.Increment()
- return nil
- }
-
iss := s.ackNumber - 1
irs := s.sequenceNumber - 1
@@ -668,9 +618,27 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// ACK was received from the sender.
return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
+
+ // Keep hold of acceptMu until the new endpoint is in the accept queue (or
+ // if there is an error), to guarantee that we will keep our spot in the
+ // queue even if another handshake from the SYN queue completes.
+ e.acceptMu.Lock()
+ if e.acceptQueue.isFull() {
+ // Silently drop the ack as the application can't accept
+ // the connection at this point. The ack will be
+ // retransmitted by the sender anyway and we can
+ // complete the connection at the time of retransmit if
+ // the backlog has space.
+ e.acceptMu.Unlock()
+ e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
- rcvdSynOptions := &header.TCPSynOptions{
+ rcvdSynOptions := header.TCPSynOptions{
MSS: mssTable[data],
// Disable Window scaling as original SYN is
// lost.
@@ -689,6 +657,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{})
if err != nil {
+ e.acceptMu.Unlock()
return err
}
@@ -700,6 +669,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
if !n.reserveTupleLocked() {
n.mu.Unlock()
+ e.acceptMu.Unlock()
n.Close()
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -717,6 +687,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n.boundBindToDevice,
); err != nil {
n.mu.Unlock()
+ e.acceptMu.Unlock()
n.Close()
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -725,25 +696,22 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
n.isRegistered = true
-
- // clear the tsOffset for the newly created
- // endpoint as the Timestamp was already
- // randomly offset when the original SYN-ACK was
- // sent above.
- n.TSOffset = 0
+ n.TSOffset = n.protocol.tsOffset(s.dstAddr, s.srcAddr)
// Switch state to connected.
n.isConnectNotified = true
- n.transitionToStateEstablishedLocked(&handshake{
- ep: n,
- iss: iss,
- ackNum: irs + 1,
- rcvWnd: seqnum.Size(n.initialReceiveWindow()),
- sndWnd: s.window,
- rcvWndScale: e.rcvWndScaleForHandshake(),
- sndWndScale: rcvdSynOptions.WS,
- mss: rcvdSynOptions.MSS,
- })
+ h := handshake{
+ ep: n,
+ iss: iss,
+ ackNum: irs + 1,
+ rcvWnd: seqnum.Size(n.initialReceiveWindow()),
+ sndWnd: s.window,
+ rcvWndScale: e.rcvWndScaleForHandshake(),
+ sndWndScale: rcvdSynOptions.WS,
+ mss: rcvdSynOptions.MSS,
+ sampleRTTWithTSOnly: true,
+ }
+ h.transitionToStateEstablishedLocked(s)
// Requeue the segment if the ACK completing the handshake has more info
// to be processed by the newly established endpoint.
@@ -752,20 +720,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n.newSegmentWaker.Assert()
}
- // Do the delivery in a separate goroutine so
- // that we don't block the listen loop in case
- // the application is slow to accept or stops
- // accepting.
- //
- // NOTE: This won't result in an unbounded
- // number of goroutines as we do check before
- // entering here that there was at least some
- // space available in the backlog.
-
// Start the protocol goroutine.
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- go e.deliverAccepted(n, true /*withSynCookie*/)
+
+ // Deliver the endpoint to the accept queue.
+ e.acceptQueue.endpoints.PushBack(n)
+ e.acceptMu.Unlock()
+
+ e.waiterQueue.Notify(waiter.ReadableEvents)
return nil
default:
@@ -779,17 +742,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
v6Only := e.ops.GetV6Only()
- ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
+ ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto)
defer func() {
- // Mark endpoint as closed. This will prevent goroutines running
- // handleSynSegment() from attempting to queue new connections
- // to the endpoint.
e.setEndpointState(StateClose)
- // Close any endpoints in SYN-RCVD state.
- ctx.closeAllPendingEndpoints()
-
// Do cleanup if needed.
e.completeWorkerLocked()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 93ed161f9..80cd07218 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -30,6 +30,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// InitialRTO is the initial retransmission timeout.
+// https://github.com/torvalds/linux/blob/7c636d4d20f/include/net/tcp.h#L142
+const InitialRTO = time.Second
+
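For orientation (assuming the backoff timer doubles its timeout on each firing, as the newBackoffTimer usage below suggests): handshake retransmissions fire roughly 1s, 2s, 4s, ... after the initial SYN or SYN-ACK, capped at MaxRTO.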
// maxSegmentsPerWake is the maximum number of segments to process in the main
// protocol goroutine per wake-up. Yielding [after this number of segments are
// processed] allows other events to be processed as well (e.g., timeouts,
@@ -105,6 +109,11 @@ type handshake struct {
// sendSYNOpts is the cached values for the SYN options to be sent.
sendSYNOpts header.TCPSynOptions
+
+ // sampleRTTWithTSOnly is true when the segment was retransmitted or we can't
+ // tell; then RTT can only be sampled when the incoming segment has timestamp
+ // options enabled.
+ sampleRTTWithTSOnly bool
}
func (e *endpoint) newHandshake() *handshake {
@@ -117,10 +126,12 @@ func (e *endpoint) newHandshake() *handshake {
h.resetState()
// Store reference to handshake state in endpoint.
e.h = h
+ // By the time handshake is created, e.ID is already initialized.
+ e.TSOffset = e.protocol.tsOffset(e.ID.LocalAddress, e.ID.RemoteAddress)
return h
}
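Editorial note on the tsOffset change above (our reading; the diff itself does not state the rationale): deriving the offset from the connection's address pair plus a protocol-wide secret, rather than a fresh random value per endpoint, keeps TSVal monotonic across successive connections to the same peer while still differing between peers, in the spirit of Linux's per-destination timestamp offsets.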
-func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake {
+func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts header.TCPSynOptions, deferAccept time.Duration) *handshake {
h := e.newHandshake()
h.resetToSynRcvd(isn, irs, opts, deferAccept)
return h
@@ -150,20 +161,23 @@ func (h *handshake) resetState() {
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed())
+ h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.protocol.seqnumSecret)
}
// generateSecureISN generates a secure Initial Sequence number based on the
// recommendation here https://tools.ietf.org/html/rfc6528#page-3.
func generateSecureISN(id stack.TransportEndpointID, clock tcpip.Clock, seed uint32) seqnum.Value {
isnHasher := jenkins.Sum32(seed)
- isnHasher.Write([]byte(id.LocalAddress))
- isnHasher.Write([]byte(id.RemoteAddress))
+ // Per hash.Hash.Write:
+ //
+ // It never returns an error.
+ _, _ = isnHasher.Write([]byte(id.LocalAddress))
+ _, _ = isnHasher.Write([]byte(id.RemoteAddress))
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, id.LocalPort)
- isnHasher.Write(portBuf)
+ _, _ = isnHasher.Write(portBuf)
binary.LittleEndian.PutUint16(portBuf, id.RemotePort)
- isnHasher.Write(portBuf)
+ _, _ = isnHasher.Write(portBuf)
// The time period here is 64ns. This is similar to what Linux uses to
// generate a sequence number that overlaps less than once
// per MSL (2 minutes).
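To ground the 64ns figure (an editorial back-of-the-envelope check, not text from the change): the clock term alone walks the full 32-bit sequence space once every 2^32 x 64ns, which is about 275 seconds, comfortably longer than the 2-minute MSL; that is exactly the overlap bound the comment claims.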
@@ -190,7 +204,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 {
// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
// state.
-func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) {
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts header.TCPSynOptions, deferAccept time.Duration) {
h.active = false
h.state = handshakeSynRcvd
h.flags = header.TCPFlagSyn | header.TCPFlagAck
@@ -251,10 +265,10 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
rcvSynOpts := parseSynSegmentOptions(s)
// Remember if the Timestamp option was negotiated.
- h.ep.maybeEnableTimestamp(&rcvSynOpts)
+ h.ep.maybeEnableTimestamp(rcvSynOpts)
// Remember if the SACKPermitted option was negotiated.
- h.ep.maybeEnableSACKPermitted(&rcvSynOpts)
+ h.ep.maybeEnableSACKPermitted(rcvSynOpts)
// Remember the sequence we'll ack from now on.
h.ackNum = s.sequenceNumber + 1
@@ -266,8 +280,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
// and the handshake is completed.
if s.flags.Contains(header.TCPFlagAck) {
h.state = handshakeCompleted
-
- h.ep.transitionToStateEstablishedLocked(h)
+ h.transitionToStateEstablishedLocked(s)
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
return nil
@@ -283,7 +296,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error {
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
- TSVal: h.ep.timestamp(),
+ TSVal: h.ep.tsValNow(),
TSEcr: h.ep.recentTimestamp(),
// We only send SACKPermitted if the other side indicated it
@@ -353,7 +366,7 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: h.ep.SendTSOk,
- TSVal: h.ep.timestamp(),
+ TSVal: h.ep.tsValNow(),
TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.SACKPermitted,
MSS: h.ep.amss,
@@ -402,9 +415,10 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
if h.ep.SendTSOk && s.parsedOptions.TS {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
+
h.state = handshakeCompleted
- h.ep.transitionToStateEstablishedLocked(h)
+ h.transitionToStateEstablishedLocked(s)
// Requeue the segment if the ACK completing the handshake has more info
// to be processed by the newly established endpoint.
@@ -480,7 +494,7 @@ func (h *handshake) start() {
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: true,
- TSVal: h.ep.timestamp(),
+ TSVal: h.ep.tsValNow(),
TSEcr: h.ep.recentTimestamp(),
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
@@ -522,7 +536,7 @@ func (h *handshake) complete() tcpip.Error {
defer s.Done()
// Initialize the resend timer.
- timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert)
+ timer, err := newBackoffTimer(h.ep.stack.Clock(), InitialRTO, MaxRTO, resendWaker.Assert)
if err != nil {
return err
}
@@ -557,6 +571,10 @@ func (h *handshake) complete() tcpip.Error {
ack: h.ackNum,
rcvWnd: h.rcvWnd,
}, h.sendSYNOpts)
+ // If we have ever retransmitted the SYN-ACK or
+ // SYN segment, we should only measure RTT if
+ // TS option is present.
+ h.sampleRTTWithTSOnly = true
}
case wakerForNotification:
@@ -564,6 +582,9 @@ func (h *handshake) complete() tcpip.Error {
if (n&notifyClose)|(n&notifyAbort) != 0 {
return &tcpip.ErrAborted{}
}
+ if n&notifyShutdown != 0 {
+ return &tcpip.ErrConnectionReset{}
+ }
if n&notifyDrain != 0 {
for !h.ep.segmentQueue.empty() {
s := h.ep.segmentQueue.dequeue()
@@ -600,6 +621,40 @@ func (h *handshake) complete() tcpip.Error {
return nil
}
+// transitionToStateEstablishedLocked transitions the endpoint of the handshake
+// to an established state given the last segment received from peer. It also
+// initializes sender/receiver.
+func (h *handshake) transitionToStateEstablishedLocked(s *segment) {
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ h.ep.snd = newSender(h.ep, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ now := h.ep.stack.Clock().NowMonotonic()
+
+ var rtt time.Duration
+ if h.ep.SendTSOk && s.parsedOptions.TSEcr != 0 {
+ rtt = h.ep.elapsed(now, s.parsedOptions.TSEcr)
+ }
+ if !h.sampleRTTWithTSOnly && rtt == 0 {
+ rtt = now.Sub(h.startTime)
+ }
+
+ if rtt > 0 {
+ h.ep.snd.updateRTO(rtt)
+ }
+
+ h.ep.rcvQueueInfo.rcvQueueMu.Lock()
+ h.ep.rcv = newReceiver(h.ep, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ h.ep.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd)
+ h.ep.rcvQueueInfo.rcvQueueMu.Unlock()
+
+ h.ep.setEndpointState(StateEstablished)
+}
+
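Hedged sketch of the sampling rule transitionToStateEstablishedLocked applies above; the helper name and parameters are hypothetical, introduced only to restate the branch structure:

import "time"

// handshakeRTT mirrors the logic above: a timestamp echo (TSEcr) remains a
// valid RTT sample even if the SYN or SYN-ACK was retransmitted, whereas a
// plain wall-clock measurement must be discarded once any retransmission
// makes the sample ambiguous (Karn's algorithm).
func handshakeRTT(tsEcrRTT, wallRTT time.Duration, sampleRTTWithTSOnly bool) time.Duration {
	if tsEcrRTT > 0 {
		return tsEcrRTT
	}
	if sampleRTTWithTSOnly {
		return 0 // ambiguous: take no sample, keep the default RTO
	}
	return wallRTT
}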
type backoffTimer struct {
timeout time.Duration
maxTimeout time.Duration
@@ -873,7 +928,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:])
+ offset += header.EncodeTSOption(e.tsValNow(), e.recentTimestamp(), options[offset:])
}
if e.SACKPermitted && len(sackBlocks) > 0 {
offset += header.EncodeNOP(options[offset:])
@@ -965,26 +1020,6 @@ func (e *endpoint) completeWorkerLocked() {
}
}
-// transitionToStateEstablisedLocked transitions a given endpoint
-// to an established state using the handshake parameters provided.
-// It also initializes sender/receiver.
-func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
- // Transfer handshake state to TCP connection. We disable
- // receive window scaling if the peer doesn't support it
- // (indicated by a negative send window scale).
- e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
-
- e.rcvQueueInfo.rcvQueueMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
- // Bootstrap the auto tuning algorithm. Starting at zero will
- // result in a really large receive window after the first auto
- // tuning adjustment.
- e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd)
- e.rcvQueueInfo.rcvQueueMu.Unlock()
-
- e.setEndpointState(StateEstablished)
-}
-
// transitionToStateCloseLocked ensures that the endpoint is
// cleaned up from the transport demuxer, "before" moving to
// StateClose. This will ensure that no packet will be
@@ -1286,7 +1321,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// protocolMainLoopDone is called at the end of protocolMainLoop.
// +checklocksrelease:e.mu
-func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer, closeWaker *sleep.Waker) {
+func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer) {
if e.snd != nil {
e.snd.resendTimer.cleanup()
e.snd.probeTimer.cleanup()
@@ -1314,7 +1349,7 @@ func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer, closeWaker *slee
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
-func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error {
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) {
var (
closeTimer tcpip.Timer
closeWaker sleep.Waker
@@ -1331,8 +1366,8 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.hardError = err
e.workerCleanup = true
- e.protocolMainLoopDone(closeTimer, &closeWaker)
- return err
+ e.protocolMainLoopDone(closeTimer)
+ return
}
}
@@ -1559,8 +1594,8 @@ loop:
// just want to terminate the loop and cleanup the
// endpoint.
cleanupOnError(nil)
- e.protocolMainLoopDone(closeTimer, &closeWaker)
- return nil
+ e.protocolMainLoopDone(closeTimer)
+ return
case StateTimeWait:
fallthrough
case StateClose:
@@ -1568,8 +1603,8 @@ loop:
default:
if err := funcs[v].f(); err != nil {
cleanupOnError(err)
- e.protocolMainLoopDone(closeTimer, &closeWaker)
- return nil
+ e.protocolMainLoopDone(closeTimer)
+ return
}
}
}
@@ -1592,21 +1627,19 @@ loop:
// Handle any StateError transition from StateTimeWait.
if e.EndpointState() == StateError {
cleanupOnError(nil)
- e.protocolMainLoopDone(closeTimer, &closeWaker)
- return nil
+ e.protocolMainLoopDone(closeTimer)
+ return
}
e.transitionToStateCloseLocked()
- e.protocolMainLoopDone(closeTimer, &closeWaker)
+ e.protocolMainLoopDone(closeTimer)
// A new SYN was received during TIME_WAIT and we need to abort
// the timewait and redirect the segment to the listener queue
if reuseTW != nil {
reuseTW()
}
-
- return nil
}
// handleTimeWaitSegments processes segments received during TIME_WAIT
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 044123185..6a798e980 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,12 +15,10 @@
package tcp
import (
- "container/list"
"encoding/binary"
"fmt"
"io"
"math"
- "math/rand"
"runtime"
"strings"
"sync/atomic"
@@ -188,6 +186,8 @@ const (
// say TIME_WAIT.
notifyTickleWorker
notifyError
+ // notifyShutdown means that a connecting socket was shut down.
+ notifyShutdown
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -204,6 +204,8 @@ type SACKInfo struct {
}
// ReceiveErrors collect segment receive errors within transport layer.
+//
+// +stateify savable
type ReceiveErrors struct {
tcpip.ReceiveErrors
@@ -233,6 +235,8 @@ type ReceiveErrors struct {
}
// SendErrors collect segment send errors within the transport layer.
+//
+// +stateify savable
type SendErrors struct {
tcpip.SendErrors
@@ -256,6 +260,8 @@ type SendErrors struct {
}
// Stats holds statistics about the endpoint.
+//
+// +stateify savable
type Stats struct {
// SegmentsReceived is the number of TCP segments received that
// the transport layer successfully parsed.
@@ -310,15 +316,6 @@ type rcvQueueInfo struct {
rcvQueue segmentList `state:"wait"`
}
-// +stateify savable
-type accepted struct {
- // NB: this could be an endpointList, but ilist only permits endpoints to
- // belong to one list at a time, and endpoints are already stored in the
- // dispatcher's list.
- endpoints list.List `state:".([]*endpoint)"`
- cap int
-}
-
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -334,7 +331,7 @@ type accepted struct {
// The following three mutexes can be acquired independent of e.mu but if
// acquired with e.mu then e.mu must be acquired first.
//
-// e.acceptMu -> protects accepted.
+// e.acceptMu -> Protects e.acceptQueue.
// e.rcvQueueMu -> Protects e.rcvQueue and associated fields.
// e.sndQueueMu -> Protects the e.sndQueue and associated fields.
// e.lastErrorMu -> Protects the lastError field.
@@ -378,6 +375,7 @@ type endpoint struct {
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
+ protocol *protocol `state:"manual"`
waiterQueue *waiter.Queue `state:"wait"`
uniqueID uint64
@@ -497,10 +495,6 @@ type endpoint struct {
// and dropped when it is.
segmentQueue segmentQueue `state:"wait"`
- // synRcvdCount is the number of connections for this endpoint that are
- // in SYN-RCVD state; this is only accessed atomically.
- synRcvdCount int32
-
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
userMSS uint16
@@ -573,7 +567,8 @@ type endpoint struct {
// accepted is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
- accepted accepted
+ // +checklocks:acceptMu
+ acceptQueue acceptQueue
// The following are only used from the protocol goroutine, and
// therefore don't need locks to protect them.
@@ -606,8 +601,7 @@ type endpoint struct {
gso stack.GSO
- // TODO(b/142022063): Add ability to save and restore per endpoint stats.
- stats Stats `state:"nosave"`
+ stats Stats
// tcpLingerTimeout is the maximum amount of time a socket
// stays in TIME_WAIT state before being marked
@@ -803,9 +797,10 @@ type keepalive struct {
waker sleep.Waker `state:"nosave"`
}
-func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
- stack: s,
+ stack: s,
+ protocol: protocol,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
TransProto: header.TCPProtocolNumber,
@@ -874,7 +869,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
}
e.segmentQueue.ep = e
- e.TSOffset = timeStampOffset(e.stack.Rand())
+
e.acceptCond = sync.NewCond(&e.acceptMu)
e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker)
@@ -903,7 +898,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// Check if there's anything in the accepted queue.
if (mask & waiter.ReadableEvents) != 0 {
e.acceptMu.Lock()
- if e.accepted.endpoints.Len() != 0 {
+ if e.acceptQueue.endpoints.Len() != 0 {
result |= waiter.ReadableEvents
}
e.acceptMu.Unlock()
@@ -1086,20 +1081,20 @@ func (e *endpoint) closeNoShutdownLocked() {
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
e.acceptMu.Lock()
- acceptedCopy := e.accepted
- e.accepted = accepted{}
- e.acceptMu.Unlock()
-
- if acceptedCopy == (accepted{}) {
- return
+ // Close any endpoints in SYN-RCVD state.
+ for n := range e.acceptQueue.pendingEndpoints {
+ n.notifyProtocolGoroutine(notifyClose)
}
-
- e.acceptCond.Broadcast()
-
+ e.acceptQueue.pendingEndpoints = nil
// Reset all connections that are waiting to be accepted.
- for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() {
+ for n := e.acceptQueue.endpoints.Front(); n != nil; n = n.Next() {
n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset)
}
+ e.acceptQueue.endpoints.Init()
+ e.acceptMu.Unlock()
+
+ e.acceptCond.Broadcast()
+
// Wait for reset of all endpoints that are still waiting to be delivered to
// the now-closed accept queue.
e.pendingAccepted.Wait()
@@ -1717,6 +1712,27 @@ func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) {
return rcvBufSz
}
+// OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize.
+func (e *endpoint) OnSetSendBufferSize(sz int64) int64 {
+ atomic.StoreUint32(&e.sndQueueInfo.TCPSndBufState.AutoTuneSndBufDisabled, 1)
+ return sz
+}
+
+// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters.
+func (e *endpoint) WakeupWriters() {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ sendBufferSize := e.getSendBufferSize()
+ e.sndQueueInfo.sndQueueMu.Lock()
+ notify := (sendBufferSize - e.sndQueueInfo.SndBufUsed) >= e.sndQueueInfo.SndBufUsed>>1
+ e.sndQueueInfo.sndQueueMu.Unlock()
+
+ if notify {
+ e.waiterQueue.Notify(waiter.WritableEvents)
+ }
+}
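Worked numbers for the notify condition above (illustrative): with a 64 KiB send buffer and 40 KiB queued, free space is 24 KiB against a threshold of SndBufUsed>>1 = 20 KiB, so writers are woken; with 48 KiB queued (16 KiB free versus a 24 KiB threshold) they are not.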
+
// SetSockOptInt sets a socket option.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
// The lower 2 bits represent the ECN bits. RFC 3168, section 23.1
@@ -2038,7 +2054,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
case *tcpip.OriginalDestinationOption:
e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto)
+ addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber)
e.UnlockUser()
if err != nil {
return err
@@ -2177,7 +2193,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
- h := jenkins.Sum32(e.stack.Seed())
+ h := jenkins.Sum32(e.protocol.portOffsetSecret)
for _, s := range [][]byte{
[]byte(e.ID.LocalAddress),
[]byte(e.ID.RemoteAddress),
@@ -2329,6 +2345,9 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
e.setEndpointState(StateEstablished)
+ // Set the new auto-tuned send buffer size after entering the
+ // established state.
+ e.ops.SetSendBufferSize(e.computeTCPSendBufferSize(), false /* notify */)
}
if run {
@@ -2355,6 +2374,18 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.LockUser()
defer e.UnlockUser()
+
+ if e.EndpointState().connecting() {
+ // When calling shutdown(2) on a connecting socket, the endpoint must
+ // enter the error state. But this logic cannot belong to the shutdownLocked
+ // method because that method is called during a close(2) (and closing a
+ // connecting socket is not an error).
+ e.resetConnectionLocked(&tcpip.ErrConnectionReset{})
+ e.notifyProtocolGoroutine(notifyShutdown)
+ e.waiterQueue.Notify(waiter.WritableEvents | waiter.EventHUp | waiter.EventErr)
+ return nil
+ }
+
return e.shutdownLocked(flags)
}
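Hedged sketch of the user-visible effect of the new connecting-state branch above (the sequence is assumed, not taken from the diff):

var (
	ep     tcpip.Endpoint    // a TCP endpoint, assumed already created
	remote tcpip.FullAddress // assumed reachable peer
)

err := ep.Connect(remote) // returns &tcpip.ErrConnectStarted{}: SYN-SENT
// Shutting down mid-handshake now aborts rather than half-closes: the
// endpoint is reset, waiters see EventHUp|EventErr, and subsequent
// operations report ErrConnectionReset.
_ = ep.Shutdown(tcpip.ShutdownRead | tcpip.ShutdownWrite)
_ = err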
@@ -2455,21 +2486,22 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
if e.EndpointState() == StateListen && !e.closed {
e.acceptMu.Lock()
defer e.acceptMu.Unlock()
- if e.accepted == (accepted{}) {
- // listen is called after shutdown.
- e.accepted.cap = backlog
- e.shutdownFlags = 0
- e.rcvQueueInfo.rcvQueueMu.Lock()
- e.rcvQueueInfo.RcvClosed = false
- e.rcvQueueInfo.rcvQueueMu.Unlock()
- } else {
- // Adjust the size of the backlog iff we can fit
- // existing pending connections into the new one.
- if e.accepted.endpoints.Len() > backlog {
- return &tcpip.ErrInvalidEndpointState{}
- }
- e.accepted.cap = backlog
+
+ // Adjust the size of the backlog iff we can fit
+ // existing pending connections into the new one.
+ if e.acceptQueue.endpoints.Len() > backlog {
+ return &tcpip.ErrInvalidEndpointState{}
}
+ e.acceptQueue.capacity = backlog
+
+ if e.acceptQueue.pendingEndpoints == nil {
+ e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{})
+ }
+
+ e.shutdownFlags = 0
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ e.rcvQueueInfo.RcvClosed = false
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
// Notify any blocked goroutines that they can attempt to
// deliver endpoints again.
@@ -2505,8 +2537,11 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
// may be pre-populated with some previously accepted (but not Accepted)
// endpoints.
e.acceptMu.Lock()
- if e.accepted == (accepted{}) {
- e.accepted.cap = backlog
+ if e.acceptQueue.pendingEndpoints == nil {
+ e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{})
+ }
+ if e.acceptQueue.capacity == 0 {
+ e.acceptQueue.capacity = backlog
}
e.acceptMu.Unlock()
@@ -2546,8 +2581,8 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
// Get the new accepted endpoint.
var n *endpoint
e.acceptMu.Lock()
- if element := e.accepted.endpoints.Front(); element != nil {
- n = e.accepted.endpoints.Remove(element).(*endpoint)
+ if element := e.acceptQueue.endpoints.Front(); element != nil {
+ n = e.acceptQueue.endpoints.Remove(element).(*endpoint)
}
e.acceptMu.Unlock()
if n == nil {
@@ -2763,13 +2798,20 @@ func (e *endpoint) updateSndBufferUsage(v int) {
e.sndQueueInfo.sndQueueMu.Lock()
notify := e.sndQueueInfo.SndBufUsed >= sendBufferSize>>1
e.sndQueueInfo.SndBufUsed -= v
+
+ // Get the new send buffer size with auto tuning, but do not set it
+ // unless we decide to notify the writers.
+ newSndBufSz := e.computeTCPSendBufferSize()
+
// We only notify when there is half the sendBufferSize available after
// a full buffer event occurs. This ensures that we don't wake up
// writers to queue just 1-2 segments and go back to sleep.
- notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1
+ notify = notify && e.sndQueueInfo.SndBufUsed < int(newSndBufSz)>>1
e.sndQueueInfo.sndQueueMu.Unlock()
if notify {
+ // Set the new send buffer size calculated from auto tuning.
+ e.ops.SetSendBufferSize(newSndBufSz, false /* notify */)
e.waiterQueue.Notify(waiter.WritableEvents)
}
}
@@ -2873,46 +2915,29 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value,
// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if
// the SYN options indicate that timestamp option was negotiated. It also
// initializes the recentTS with the value provided in synOpts.TSval.
-func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
+func (e *endpoint) maybeEnableTimestamp(synOpts header.TCPSynOptions) {
if synOpts.TS {
e.SendTSOk = true
e.setRecentTimestamp(synOpts.TSVal)
}
}
-// timestamp returns the timestamp value to be used in the TSVal field of the
-// timestamp option for outgoing TCP segments for a given endpoint.
-func (e *endpoint) timestamp() uint32 {
- return tcpTimeStamp(e.stack.Clock().NowMonotonic(), e.TSOffset)
+func (e *endpoint) tsVal(now tcpip.MonotonicTime) uint32 {
+ return e.TSOffset.TSVal(now)
}
-// tcpTimeStamp returns a timestamp offset by the provided offset. This is
-// not inlined above as it's used when SYN cookies are in use and endpoint
-// is not created at the time when the SYN cookie is sent.
-func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 {
- d := curTime.Sub(tcpip.MonotonicTime{})
- return uint32(d.Milliseconds()) + offset
+func (e *endpoint) tsValNow() uint32 {
+ return e.tsVal(e.stack.Clock().NowMonotonic())
}
-// timeStampOffset returns a randomized timestamp offset to be used when sending
-// timestamp values in a timestamp option for a TCP segment.
-func timeStampOffset(rng *rand.Rand) uint32 {
- // Initialize a random tsOffset that will be added to the recentTS
- // everytime the timestamp is sent when the Timestamp option is enabled.
- //
- // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
- // why this is required.
- //
- // NOTE: This is not completely to spec as normally this should be
- // initialized in a manner analogous to how sequence numbers are
- // randomized per connection basis. But for now this is sufficient.
- return rng.Uint32()
+func (e *endpoint) elapsed(now tcpip.MonotonicTime, tsEcr uint32) time.Duration {
+ return e.TSOffset.Elapsed(now, tsEcr)
}
// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
// if the SYN options indicate that the SACK option was negotiated and the TCP
// stack is configured to enable TCP SACK option.
-func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
+func (e *endpoint) maybeEnableSACKPermitted(synOpts header.TCPSynOptions) {
var v tcpip.TCPSACKEnabled
if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
@@ -2974,6 +2999,8 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState {
}
s.Sender.RACKState = e.snd.rc.TCPRACKState
+ s.Sender.RetransmitTS = e.snd.retransmitTS
+ s.Sender.SpuriousRecovery = e.snd.spuriousRecovery
return s
}
@@ -3091,3 +3118,36 @@ func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOpti
Max: ss.Max,
}
}
+
+// computeTCPSendBufferSize implements auto tuning of send buffer size and
+// returns the new send buffer size.
+func (e *endpoint) computeTCPSendBufferSize() int64 {
+ curSndBufSz := int64(e.getSendBufferSize())
+
+ // Auto tuning is disabled when the user explicitly sets the send
+ // buffer size with SO_SNDBUF option.
+ if disabled := atomic.LoadUint32(&e.sndQueueInfo.TCPSndBufState.AutoTuneSndBufDisabled); disabled == 1 {
+ return curSndBufSz
+ }
+
+ const packetOverheadFactor = 2
+ curMSS := e.snd.MaxPayloadSize
+ numSeg := InitialCwnd
+ if numSeg < e.snd.SndCwnd {
+ numSeg = e.snd.SndCwnd
+ }
+
+	// SndCwnd indicates the number of segments that can be sent. This means
+	// that the sender can send up to #SndCwnd segments, so the send buffer
+	// size should be set to SndCwnd*MSS to accommodate all of them.
+ newSndBufSz := int64(numSeg * curMSS * packetOverheadFactor)
+ if newSndBufSz < curSndBufSz {
+ return curSndBufSz
+ }
+ if ss := GetTCPSendBufferLimits(e.stack); int64(ss.Max) < newSndBufSz {
+ newSndBufSz = int64(ss.Max)
+ }
+
+ return newSndBufSz
+}
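
As a worked example of the rule above: with MSS = 1460 and SndCwnd = 40 (larger than InitialCwnd), the target is 40 * 1460 * 2 = 116800 bytes, clamped to the stack's configured maximum and never below the current size. A self-contained re-derivation of the arithmetic (the atomic auto-tuning-disabled check is omitted; names are illustrative):

package main

import "fmt"

const packetOverheadFactor = 2

// computeSendBuf re-derives the sizing rule above with all state passed
// in explicitly: max(initialCwnd, sndCwnd) * mss * packetOverheadFactor,
// clamped to [cur, maxLimit].
func computeSendBuf(cur, maxLimit int64, mss, sndCwnd, initialCwnd int) int64 {
	numSeg := initialCwnd
	if numSeg < sndCwnd {
		numSeg = sndCwnd
	}
	newSz := int64(numSeg * mss * packetOverheadFactor)
	if newSz < cur {
		return cur // auto tuning never shrinks the buffer
	}
	if newSz > maxLimit {
		newSz = maxLimit
	}
	return newSz
}

func main() {
	fmt.Println(computeSendBuf(65536, 4<<20, 1460, 40, 10)) // 116800
}
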
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 952ccacdd..94072a115 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -100,7 +100,7 @@ func (e *endpoint) beforeSave() {
}
// saveEndpoints is invoked by stateify.
-func (a *accepted) saveEndpoints() []*endpoint {
+func (a *acceptQueue) saveEndpoints() []*endpoint {
acceptedEndpoints := make([]*endpoint, a.endpoints.Len())
for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() {
acceptedEndpoints[i] = e.Value.(*endpoint)
@@ -109,7 +109,7 @@ func (a *accepted) saveEndpoints() []*endpoint {
}
// loadEndpoints is invoked by stateify.
-func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) {
+func (a *acceptQueue) loadEndpoints(acceptedEndpoints []*endpoint) {
for _, ep := range acceptedEndpoints {
a.endpoints.PushBack(ep)
}
@@ -170,6 +170,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
snd.probeTimer.init(s.Clock(), &snd.probeWaker)
}
e.stack = s
+ e.protocol = protocolFromStack(s)
e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits)
e.segmentQueue.thaw()
epState := EndpointState(e.origEndpointState)
@@ -250,7 +251,9 @@ func (e *endpoint) Resume(s *stack.Stack) {
go func() {
connectedLoading.Wait()
bind()
- backlog := e.accepted.cap
+ e.acceptMu.Lock()
+ backlog := e.acceptQueue.capacity
+ e.acceptMu.Unlock()
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 2e709ed78..128ef09e3 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -54,7 +54,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
maxInFlight: maxInFlight,
handler: handler,
inFlight: make(map[stack.TransportEndpointID]struct{}),
- listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
+ listen: newListenContext(s, protocolFromStack(s), nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
}
}
@@ -152,7 +152,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
}
f := r.forwarder
- ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{
+ ep, err := f.listen.performHandshake(r.segment, header.TCPSynOptions{
MSS: r.synOptions.MSS,
WS: r.synOptions.WS,
TS: r.synOptions.TS,
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 18b834243..e4410ad93 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -23,8 +23,10 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
+ "gvisor.dev/gvisor/pkg/tcpip/internal/tcp"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
@@ -49,10 +51,6 @@ const (
// MaxBufferSize is the largest size a receive/send buffer can grow to.
MaxBufferSize = 4 << 20 // 4MB
- // MaxUnprocessedSegments is the maximum number of unprocessed segments
- // that can be queued for a given endpoint.
- MaxUnprocessedSegments = 300
-
// DefaultTCPLingerTimeout is the amount of time that sockets linger in
// FIN_WAIT_2 state before being marked closed.
DefaultTCPLingerTimeout = 60 * time.Second
@@ -96,6 +94,11 @@ type protocol struct {
maxRetries uint32
synRetries uint8
dispatcher dispatcher
+
+ // The following secrets are initialized once and stay unchanged after.
+ seqnumSecret uint32
+ portOffsetSecret uint32
+ tsOffsetSecret uint32
}
// Number returns the tcp protocol number.
@@ -105,7 +108,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
// NewEndpoint creates a new tcp endpoint.
func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
- return newEndpoint(p.stack, netProto, waiterQueue), nil
+ return newEndpoint(p.stack, p, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
@@ -156,6 +159,24 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID,
return stack.UnknownDestinationPacketHandled
}
+func (p *protocol) tsOffset(src, dst tcpip.Address) tcp.TSOffset {
+ // Initialize a random tsOffset that will be added to the recentTS
+	// every time the timestamp is sent when the Timestamp option is enabled.
+ //
+ // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
+ // why this is required.
+ //
+ // TODO(https://gvisor.dev/issues/6473): This is not really secure as
+ // it does not use the recommended algorithm linked above.
+ h := jenkins.Sum32(p.tsOffsetSecret)
+ // Per hash.Hash.Writer:
+ //
+ // It never returns an error.
+ _, _ = h.Write([]byte(src))
+ _, _ = h.Write([]byte(dst))
+ return tcp.NewTSOffset(h.Sum32())
+}
+
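
The offset above is a keyed hash over (src, dst), so each connection pair gets a stable offset while the boot-time secret keeps it unpredictable to peers. A minimal sketch of the same construction, assuming the standard library's FNV hash in place of the stack's internal jenkins package:

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// tsOffsetFor derives a stable per-(src, dst) timestamp offset from a
// boot-time secret, mirroring the keyed-hash construction above.
func tsOffsetFor(secret uint32, src, dst []byte) uint32 {
	h := fnv.New32a()
	var b [4]byte
	binary.LittleEndian.PutUint32(b[:], secret)
	// Per hash.Hash, Write never returns an error.
	_, _ = h.Write(b[:])
	_, _ = h.Write(src)
	_, _ = h.Write(dst)
	return h.Sum32()
}

func main() {
	fmt.Println(tsOffsetFor(0xdeadbeef, []byte{10, 0, 0, 1}, []byte{10, 0, 0, 2}))
}
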
// replyWithReset replies to the given segment with a reset segment.
//
// If the passed TTL is 0, then the route's default TTL will be used.
@@ -292,22 +313,26 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip
case *tcpip.TCPMinRTOOption:
p.mu.Lock()
+ defer p.mu.Unlock()
if *v < 0 {
p.minRTO = MinRTO
+ } else if minRTO := time.Duration(*v); minRTO <= p.maxRTO {
+ p.minRTO = minRTO
} else {
- p.minRTO = time.Duration(*v)
+ return &tcpip.ErrInvalidOptionValue{}
}
- p.mu.Unlock()
return nil
case *tcpip.TCPMaxRTOOption:
p.mu.Lock()
+ defer p.mu.Unlock()
if *v < 0 {
p.maxRTO = MaxRTO
+ } else if maxRTO := time.Duration(*v); maxRTO >= p.minRTO {
+ p.maxRTO = maxRTO
} else {
- p.maxRTO = time.Duration(*v)
+ return &tcpip.ErrInvalidOptionValue{}
}
- p.mu.Unlock()
return nil
case *tcpip.TCPMaxRetriesOption:
@@ -479,7 +504,15 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
maxRTO: MaxRTO,
maxRetries: MaxRetries,
recovery: tcpip.TCPRACKLossDetection,
+ seqnumSecret: s.Rand().Uint32(),
+ portOffsetSecret: s.Rand().Uint32(),
+ tsOffsetSecret: s.Rand().Uint32(),
}
p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0))
return &p
}
+
+// protocolFromStack retrieves the tcp.protocol instance from stack s.
+func protocolFromStack(s *stack.Stack) *protocol {
+ return s.TransportProtocolInstance(ProtocolNumber).(*protocol)
+}
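
With the SetOption change above, TCPMinRTOOption and TCPMaxRTOOption now validate against each other instead of silently accepting an inverted range. A small sketch of the invariant being enforced (hypothetical types, not the stack's API):

package main

import (
	"errors"
	"fmt"
	"time"
)

// rtoBounds maintains minRTO <= maxRTO, mirroring the validation added
// to SetOption above.
type rtoBounds struct{ min, max time.Duration }

func (b *rtoBounds) setMin(d time.Duration) error {
	if d > b.max {
		return errors.New("min RTO exceeds max RTO")
	}
	b.min = d
	return nil
}

func (b *rtoBounds) setMax(d time.Duration) error {
	if d < b.min {
		return errors.New("max RTO below min RTO")
	}
	b.max = d
	return nil
}

func main() {
	b := rtoBounds{min: 200 * time.Millisecond, max: 120 * time.Second}
	fmt.Println(b.setMin(3 * time.Minute)) // rejected: exceeds max
	fmt.Println(b.setMax(time.Minute))     // accepted
}
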
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index 0da4eafaa..3b055c294 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -80,7 +80,6 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2
func (rc *rackControl) update(seg *segment, ackSeg *segment) {
rtt := rc.snd.ep.stack.Clock().NowMonotonic().Sub(seg.xmitTime)
- tsOffset := rc.snd.ep.TSOffset
// If the ACK is for a retransmitted packet, do not update if it is a
// spurious inference which is determined by below checks:
@@ -92,7 +91,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) {
// step 2
if seg.xmitCount > 1 {
if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 {
- if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, tsOffset) {
+ if ackSeg.parsedOptions.TSEcr < rc.snd.ep.tsVal(seg.xmitTime) {
return
}
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 9ce8fcae9..90e493978 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -477,7 +477,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
// segments. This ensures that we always leave some space for the inorder
// segments to arrive allowing pending segments to be processed and
// delivered to the user.
- if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 {
+ if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && (r.PendingBufUsed+int(segLen)) < int(rcvBufSize)>>2 {
r.ep.rcvQueueInfo.rcvQueueMu.Lock()
r.PendingBufUsed += s.segMemSize()
r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
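
The corrected check above charges the incoming segment's own length against the quarter-of-receive-buffer cap before queueing it, rather than after. A standalone sketch of the admission predicate (names illustrative):

package main

import "fmt"

// admitOutOfOrder reports whether a pending segment of length segLen may
// be queued while keeping the pending queue under a quarter of the
// receive buffer, as in the corrected check above.
func admitOutOfOrder(rcvBufSize, pendingBufUsed, segLen int) bool {
	return rcvBufSize > 0 && pendingBufUsed+segLen < rcvBufSize>>2
}

func main() {
	// 64 KiB buffer: the cap is 16 KiB. A 1500-byte segment on top of
	// 15 KiB already pending is now rejected instead of accepted.
	fmt.Println(admitOutOfOrder(64<<10, 15<<10, 1500)) // false
}
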
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
index 8a026ec46..e47a07030 100644
--- a/pkg/tcpip/transport/tcp/rcv_test.go
+++ b/pkg/tcpip/transport/tcp/rcv_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package rcv_test
+package tcp_test
import (
"testing"
diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go
index 2e6ea06f5..2d5fdda19 100644
--- a/pkg/tcpip/transport/tcp/segment_test.go
+++ b/pkg/tcpip/transport/tcp/segment_test.go
@@ -34,7 +34,7 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW
DataSize: seg.data.Size(),
SegMemSize: seg.segMemSize(),
}
- if diff := cmp.Diff(got, want); diff != "" {
+ if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("%s differs (-want +got):\n%s", name, diff)
}
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 92a66f17e..4377f07a0 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -144,6 +144,15 @@ type sender struct {
// probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm.
probeTimer timer `state:"nosave"`
probeWaker sleep.Waker `state:"nosave"`
+
+ // spuriousRecovery indicates whether the sender entered recovery
+ // spuriously as described in RFC3522 Section 3.2.
+ spuriousRecovery bool
+
+ // retransmitTS is the timestamp at which the sender sends retransmitted
+ // segment after entering an RTO for the first time as described in
+ // RFC3522 Section 3.2.
+ retransmitTS uint32
}
// rtt is a synchronization wrapper used to appease stateify. See the comment
@@ -382,6 +391,9 @@ func (s *sender) updateRTO(rtt time.Duration) {
if s.RTO < s.minRTO {
s.RTO = s.minRTO
}
+ if s.RTO > s.maxRTO {
+ s.RTO = s.maxRTO
+ }
}
// resendSegment resends the first unacknowledged segment.
@@ -422,6 +434,13 @@ func (s *sender) retransmitTimerExpired() bool {
return true
}
+ // Initialize the variables used to detect spurious recovery after
+ // entering RTO.
+ //
+ // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1.
+ s.spuriousRecovery = false
+ s.retransmitTS = 0
+
// TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases
// when writeList is empty. Remove this once we have a proper fix for this
// issue.
@@ -492,6 +511,10 @@ func (s *sender) retransmitTimerExpired() bool {
s.leaveRecovery()
}
+ // Record retransmitTS if the sender is not in recovery as per:
+ // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ s.recordRetransmitTS()
+
s.state = tcpip.RTORecovery
s.cc.HandleRTOExpired()
@@ -955,6 +978,13 @@ func (s *sender) sendData() {
}
func (s *sender) enterRecovery() {
+ // Initialize the variables used to detect spurious recovery after
+ // entering recovery.
+ //
+ // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1.
+ s.spuriousRecovery = false
+ s.retransmitTS = 0
+
s.FastRecovery.Active = true
// Save state to reflect we're now in fast recovery.
//
@@ -969,6 +999,11 @@ func (s *sender) enterRecovery() {
s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding
s.FastRecovery.HighRxt = s.SndUna
s.FastRecovery.RescueRxt = s.SndUna
+
+ // Record retransmitTS if the sender is not in recovery as per:
+ // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ s.recordRetransmitTS()
+
if s.ep.SACKPermitted {
s.state = tcpip.SACKRecovery
s.ep.stack.Stats().TCP.SACKRecovery.Increment()
@@ -1144,13 +1179,15 @@ func (s *sender) isDupAck(seg *segment) bool {
// Iterate the writeList and update RACK for each segment which is newly acked
// either cumulatively or selectively. Loop through the segments which are
// sacked, and update the RACK related variables and check for reordering.
+// It returns true when a DSACK block is detected in the received ACK.
//
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// steps 2 and 3.
-func (s *sender) walkSACK(rcvdSeg *segment) {
+func (s *sender) walkSACK(rcvdSeg *segment) bool {
s.rc.setDSACKSeen(false)
// Look for DSACK block.
+ hasDSACK := false
idx := 0
n := len(rcvdSeg.parsedOptions.SACKBlocks)
if checkDSACK(rcvdSeg) {
@@ -1164,10 +1201,11 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
s.rc.setDSACKSeen(true)
idx = 1
n--
+ hasDSACK = true
}
if n == 0 {
- return
+ return hasDSACK
}
// Sort the SACK blocks. The first block is the most recent unacked
@@ -1190,6 +1228,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
seg = seg.Next()
}
}
+ return hasDSACK
}
// checkDSACK checks if a DSACK is reported.
@@ -1236,6 +1275,85 @@ func checkDSACK(rcvdSeg *segment) bool {
return false
}
+func (s *sender) recordRetransmitTS() {
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2
+ //
+ // The Eifel detection algorithm is used, only upon initiation of loss
+ // recovery, i.e., when either the timeout-based retransmit or the fast
+ // retransmit is sent. The Eifel detection algorithm MUST NOT be
+ // reinitiated after loss recovery has already started. In particular,
+ // it must not be reinitiated upon subsequent timeouts for the same
+ // segment, and not upon retransmitting segments other than the oldest
+ // outstanding segment, e.g., during selective loss recovery.
+ if s.inRecovery() {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ //
+ // Set a "RetransmitTS" variable to the value of the Timestamp Value
+ // field of the Timestamps option included in the retransmit sent when
+ // loss recovery is initiated. A TCP sender must ensure that
+ // RetransmitTS does not get overwritten as loss recovery progresses,
+ // e.g., in case of a second timeout and subsequent second retransmit of
+ // the same octet.
+ s.retransmitTS = s.ep.tsValNow()
+}
+
+func (s *sender) detectSpuriousRecovery(hasDSACK bool, tsEchoReply uint32) {
+ // Return if the sender has already detected spurious recovery.
+ if s.spuriousRecovery {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 4
+ //
+ // If the value of the Timestamp Echo Reply field of the acceptable ACK's
+ // Timestamps option is smaller than the value of RetransmitTS, then
+ // proceed to next step, else return.
+ if tsEchoReply >= s.retransmitTS {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5
+ //
+ // If the acceptable ACK carries a DSACK option [RFC2883], then return.
+ if hasDSACK {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5
+ //
+ // If during the lifetime of the TCP connection the TCP sender has
+ // previously received an ACK with a DSACK option, or the acceptable ACK
+ // does not acknowledge all outstanding data, then proceed to next step,
+ // else return.
+ numDSACK := s.ep.stack.Stats().TCP.SegmentsAckedWithDSACK.Value()
+ if numDSACK == 0 && s.SndUna == s.SndNxt {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 6
+ //
+ // If the loss recovery has been initiated with a timeout-based
+ // retransmit, then set
+ // SpuriousRecovery <- SPUR_TO (equal 1),
+ // else set
+ // SpuriousRecovery <- dupacks+1
+ // Set the spurious recovery variable to true as we do not differentiate
+ // between fast, SACK or RTO recovery.
+ s.spuriousRecovery = true
+ s.ep.stack.Stats().TCP.SpuriousRecovery.Increment()
+}
+
+// inRecovery returns true if the sender is in RTORecovery, FastRecovery or
+// SACKRecovery state.
+func (s *sender) inRecovery() bool {
+	return s.state == tcpip.RTORecovery || s.state == tcpip.FastRecovery || s.state == tcpip.SACKRecovery
+}
+
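
Taken together, recordRetransmitTS and detectSpuriousRecovery reduce the RFC 3522 checks to a short decision chain once recovery is active. A condensed, self-contained sketch of that chain (not the sender's actual state machine):

package main

import "fmt"

// spurious reports whether a recovery episode was spurious per the checks
// above: the ACK must echo a timestamp older than the first retransmit,
// carry no DSACK, and either a DSACK must have been seen before on the
// connection or some data must still be outstanding.
func spurious(tsEcr, retransmitTS uint32, hasDSACK, everSawDSACK, allAcked bool) bool {
	if tsEcr >= retransmitTS {
		return false // the ACK covers the retransmit, not the original
	}
	if hasDSACK {
		return false // a DSACK proves the retransmit was received
	}
	if !everSawDSACK && allAcked {
		return false // cannot rule out a genuine loss
	}
	return true
}

func main() {
	fmt.Println(spurious(99, 100, false, false, false)) // true
}
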
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
@@ -1251,6 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// Insert SACKBlock information into our scoreboard.
+ hasDSACK := false
if s.ep.SACKPermitted {
for _, sb := range rcvdSeg.parsedOptions.SACKBlocks {
// Only insert the SACK block if the following holds
@@ -1285,7 +1404,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// RACK.fack, then the corresponding packet has been
// reordered and RACK.reord is set to TRUE.
if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
- s.walkSACK(rcvdSeg)
+ hasDSACK = s.walkSACK(rcvdSeg)
}
s.SetPipe()
}
@@ -1342,10 +1461,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// some new data, i.e., only if it advances the left edge of
// the send window.
if s.ep.SendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 {
- // TSVal/Ecr values sent by Netstack are at a millisecond
- // granularity.
- elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond
- s.updateRTO(elapsed)
+ s.updateRTO(s.ep.elapsed(s.ep.stack.Clock().NowMonotonic(), rcvdSeg.parsedOptions.TSEcr))
}
if s.shouldSchedulePTO() {
@@ -1415,12 +1531,14 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
ackLeft -= datalen
}
- // Update the send buffer usage and notify potential waiters.
- s.ep.updateSndBufferUsage(int(acked))
-
// Clear SACK information for all acked data.
s.ep.scoreboard.Delete(s.SndUna)
+ // Detect if the sender entered recovery spuriously.
+ if s.inRecovery() {
+ s.detectSpuriousRecovery(hasDSACK, rcvdSeg.parsedOptions.TSEcr)
+ }
+
// If we are not in fast recovery then update the congestion
// window based on the number of acknowledged packets.
if !s.FastRecovery.Active {
@@ -1437,6 +1555,9 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
}
+ // Update the send buffer usage and notify potential waiters.
+ s.ep.updateSndBufferUsage(int(acked))
+
// It is possible for s.outstanding to drop below zero if we get
// a retransmit timeout, reset outstanding to zero but later
// get an ack that cover previously sent data.
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index 89e9fb886..0d36d0dd0 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -33,7 +33,6 @@ const (
tsOptionSize = 12
maxTCPOptionSize = 40
mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
- latency = 5 * time.Millisecond
)
func setStackTCPRecovery(t *testing.T, c *context.Context, recovery int) {
@@ -163,7 +162,10 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en
if !enableRACK {
setStackTCPRecovery(t, c, 0)
}
- createConnectedWithSACKAndTS(c)
+	// The delay should be below the initial RTO (1s), otherwise retransmission
+	// will start. Choose a relatively large value so that the estimated RTT
+	// stays high even after a few rounds of undelayed RTT samples.
+ c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true}, 800*time.Millisecond /* delay */)
data := make([]byte, numPackets*maxPayload)
for i := range data {
@@ -181,9 +183,6 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en
for i := 0; i < numPackets; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
- // This delay is added to increase RTT as low RTT can cause TLP
- // before sending ACK.
- time.Sleep(latency)
}
return data
@@ -1060,16 +1059,17 @@ func TestRACKWithWindowFull(t *testing.T) {
for i := 0; i < numPkts; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
- if i == 0 {
- // Send ACK for the first packet to establish RTT.
- c.SendAck(seq, maxPayload)
- }
}
- // SACK for #10 packet.
- start := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload))
+ // Expect retransmission of last packet due to TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, (numPkts-1)*maxPayload, maxPayload, tsOptionSize)
+
+ // SACK for first and last packet.
+ start := c.IRS.Add(seqnum.Size(maxPayload))
end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{start, end}})
+ dsackStart := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload))
+ dsackEnd := dsackStart.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}})
var info tcpip.TCPInfoOption
if err := c.EP.GetSockOpt(&info); err != nil {
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 83e0653b9..896249d2d 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -23,6 +23,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -35,13 +36,13 @@ import (
// SACKPermitted option enabled if the stack in the context has the SACK support
// enabled.
func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
}
// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS
// option enabled if the stack in the context has SACK and TS enabled.
func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
}
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
@@ -108,7 +109,7 @@ func TestSackDisabledConnect(t *testing.T) {
setStackSACKPermitted(t, c, sackEnabled)
setStackTCPRecovery(t, c, 0)
- rep := c.CreateConnectedWithOptions(header.TCPSynOptions{})
+ rep := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
data := []byte{1, 2, 3}
@@ -170,7 +171,7 @@ func TestSackPermittedAccept(t *testing.T) {
setStackSACKPermitted(t, c, sackEnabled)
setStackTCPRecovery(t, c, 0)
- rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
+ rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
// Now verify no SACK blocks are
// received when sack is disabled.
data := []byte{1, 2, 3}
@@ -244,7 +245,7 @@ func TestSackDisabledAccept(t *testing.T) {
setStackSACKPermitted(t, c, sackEnabled)
setStackTCPRecovery(t, c, 0)
- rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ rep := c.AcceptWithOptionsNoDelay(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Now verify no SACK blocks are
// received when sack is disabled.
@@ -702,3 +703,257 @@ func TestRecoveryEntry(t *testing.T) {
t.Error(err)
}
}
+
+func verifySpuriousRecoveryMetric(t *testing.T, c *context.Context, numSpuriousRecovery uint64) {
+ t.Helper()
+
+ metricPollFn := func() error {
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ {tcpStats.SpuriousRecovery, "stats.TCP.SpuriousRecovery", numSpuriousRecovery},
+ }
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+ }
+ }
+ return nil
+ }
+
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
+}
+
+func checkReceivedPacket(t *testing.T, c *context.Context, tcpHdr header.TCP, bytesRead uint32, b, data []byte) {
+ payloadLen := uint32(len(tcpHdr.Payload()))
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPSeqNum(uint32(c.IRS)+1+bytesRead),
+ checker.TCPAckNum(context.TestInitialSequenceNumber+1),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
+ ),
+ )
+ pdata := data[bytesRead : bytesRead+payloadLen]
+ if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) {
+ t.Fatalf("got data = %v, want = %v", p, pdata)
+ }
+}
+
+func buildTSOptionFromHeader(tcpHdr header.TCP) []byte {
+ parsedOpts := tcpHdr.ParsedOptions()
+ tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
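+	// Advance our own TSVal (which the endpoint echoed back to us in its
+	// TSEcr) by one tick, and echo the endpoint's TSVal in our TSEcr.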
+ header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:])
+ return tsOpt[:]
+}
+
+func TestDetectSpuriousRecoveryWithRTO(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan struct{})
+ c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
+ if s.Sender.RetransmitTS == 0 {
+ t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0")
+ }
+ if !s.Sender.SpuriousRecovery {
+ t.Fatalf("Spurious recovery was not detected")
+ }
+ close(probeDone)
+ })
+
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := make([]byte, numPackets*maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+ // Write the data.
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ var options []byte
+ var bytesRead uint32
+ for i := 0; i < numPackets; i++ {
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data)
+
+ // Get options only for the first packet. This will be sent with
+ // the ACK to indicate the acknowledgement is for the original
+ // packet.
+ if i == 0 && c.TimeStampEnabled {
+ options = buildTSOptionFromHeader(tcpHdr)
+ }
+ bytesRead += uint32(len(tcpHdr.Payload()))
+ }
+
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ // Expect #5 segment with TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ // Expect #1 segment because of RTO.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ info := tcpip.TCPInfoOption{}
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+ }
+
+ if info.CcState != tcpip.RTORecovery {
+ t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.RTORecovery)
+ }
+
+ // Acknowledge the data.
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seq,
+ AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)),
+ RcvWnd: rcvWnd,
+ TCPOpts: options,
+ })
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ <-probeDone
+
+ verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */)
+}
+
+func TestSACKDetectSpuriousRecoveryWithDupACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ numAck := 0
+ probeDone := make(chan struct{})
+ c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
+ if numAck < 3 {
+ numAck++
+ return
+ }
+
+ if s.Sender.RetransmitTS == 0 {
+ t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0")
+ }
+ if !s.Sender.SpuriousRecovery {
+ t.Fatalf("Spurious recovery was not detected")
+ }
+ close(probeDone)
+ })
+
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := make([]byte, numPackets*maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+ // Write the data.
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ var options []byte
+ var bytesRead uint32
+ for i := 0; i < numPackets; i++ {
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data)
+
+ // Get options only for the first packet. This will be sent with
+ // the ACK to indicate the acknowledgement is for the original
+ // packet.
+ if i == 0 && c.TimeStampEnabled {
+ options = buildTSOptionFromHeader(tcpHdr)
+ }
+ bytesRead += uint32(len(tcpHdr.Payload()))
+ }
+
+ // Receive the retransmitted packet after TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ // Send ACK for #3 and #4 segments to avoid entering TLP.
+ start := c.IRS.Add(3*maxPayload + 1)
+ end := start.Add(2 * maxPayload)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+
+ c.SendAck(seq, 0 /* bytesReceived */)
+ c.SendAck(seq, 0 /* bytesReceived */)
+
+ // Receive the retransmitted packet after three duplicate ACKs.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ info := tcpip.TCPInfoOption{}
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+ }
+
+ if info.CcState != tcpip.SACKRecovery {
+ t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.SACKRecovery)
+ }
+
+ // Acknowledge the data.
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seq,
+ AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)),
+ RcvWnd: rcvWnd,
+ TCPOpts: options,
+ })
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ <-probeDone
+
+ verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */)
+}
+
+func TestNoSpuriousRecoveryWithDSACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
+
+ // Receive the retransmitted packet after TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ // Send ACK for #3 and #4 segments to avoid entering TLP.
+ start := c.IRS.Add(3*maxPayload + 1)
+ end := start.Add(2 * maxPayload)
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+
+ c.SendAck(seq, 0 /* bytesReceived */)
+ c.SendAck(seq, 0 /* bytesReceived */)
+
+ // Receive the retransmitted packet after three duplicate ACKs.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ // Acknowledge the data with DSACK for #1 segment.
+ start = c.IRS.Add(maxPayload + 1)
+ end = start.Add(2 * maxPayload)
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 6*maxPayload, []header.SACKBlock{{start, end}})
+
+ verifySpuriousRecoveryMetric(t, c, 0 /* numSpuriousRecovery */)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 031f01357..6f1ee3816 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -1381,8 +1382,12 @@ func TestListenerReadinessOnEvent(t *testing.T) {
if err := s.CreateNIC(id, ep); err != nil {
t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err)
}
- if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil {
- t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: id},
@@ -1651,6 +1656,71 @@ func TestConnectBindToDevice(t *testing.T) {
}
}
+func TestShutdownConnectingSocket(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ shutdownMode tcpip.ShutdownFlags
+ }{
+ {"ShutdownRead", tcpip.ShutdownRead},
+ {"ShutdownWrite", tcpip.ShutdownWrite},
+ {"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create an endpoint, don't handshake because we want to interfere with
+ // the handshake process.
+ c.Create(-1)
+
+ waitEntry, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ // Start connection attempt.
+ addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, c.EP.Connect(addr)); d != "" {
+ t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
+ }
+
+ // Check the SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+
+ if err := c.EP.Shutdown(test.shutdownMode); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ // The endpoint internal state is updated immediately.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+
+ select {
+ case <-ch:
+ default:
+ t.Fatal("endpoint was not notified")
+ }
+
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, &tcpip.ErrConnectionReset{})
+
+ // If the endpoint is not properly shutdown, it'll re-attempt to connect
+ // by sending another ACK packet.
+ c.CheckNoPacketTimeout("got an unexpected packet", tcp.InitialRTO+(500*time.Millisecond))
+ })
+ }
+}
+
func TestSynSent(t *testing.T) {
for _, test := range []struct {
name string
@@ -1674,7 +1744,7 @@ func TestSynSent(t *testing.T) {
addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
err := c.EP.Connect(addr)
- if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" {
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
}
@@ -1990,7 +2060,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
)
// Cause a FIN to be generated.
- c.EP.Shutdown(tcpip.ShutdownWrite)
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
// Make sure we get the FIN but DON't ACK IT.
checker.IPv4(t, c.GetPacket(),
@@ -2006,7 +2078,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
// Cause a RST to be generated by closing the read end now since we have
// unread data.
- c.EP.Shutdown(tcpip.ShutdownRead)
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
// Make sure we get the RST
checker.IPv4(t, c.GetPacket(),
@@ -2127,6 +2201,214 @@ func TestFullWindowReceive(t *testing.T) {
)
}
+func TestSmallReceiveBufferReadiness(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ })
+
+ ep := loopback.New()
+ if testing.Verbose() {
+ ep = sniffer.New(ep)
+ }
+
+ const nicID = 1
+ nicOpts := stack.NICOptions{Name: "nic1"}
+ if err := s.CreateNICWithOptions(nicID, ep, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err)
+ }
+
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x7f\x00\x00\x01"),
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
+ if err != nil {
+ t.Fatalf("tcpip.NewSubnet failed: %s", err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: subnet,
+ NIC: nicID,
+ },
+ })
+ }
+
+ listenerEntry, listenerCh := waiter.NewChannelEntry(nil)
+ var listenerWQ waiter.Queue
+ listener, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer listener.Close()
+ listenerWQ.EventRegister(&listenerEntry, waiter.ReadableEvents)
+ defer listenerWQ.EventUnregister(&listenerEntry)
+
+ if err := listener.Bind(tcpip.FullAddress{}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := listener.Listen(1); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ localAddress, err := listener.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("GetLocalAddress failed: %s", err)
+ }
+
+ for i := 8; i > 0; i /= 2 {
+ size := int64(i << 10)
+ t.Run(fmt.Sprintf("size=%d", size), func(t *testing.T) {
+ var clientWQ waiter.Queue
+ client, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer client.Close()
+ switch err := client.Connect(localAddress).(type) {
+ case nil:
+ t.Fatal("Connect returned nil error")
+ case *tcpip.ErrConnectStarted:
+ default:
+ t.Fatalf("Connect failed: %s", err)
+ }
+
+ <-listenerCh
+ server, serverWQ, err := listener.Accept(nil)
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+ defer server.Close()
+
+ client.SocketOptions().SetReceiveBufferSize(size, true)
+ // Send buffer size doesn't seem to affect this test.
+ // server.SocketOptions().SetSendBufferSize(size, true)
+
+ clientEntry, clientCh := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientEntry, waiter.ReadableEvents)
+ defer clientWQ.EventUnregister(&clientEntry)
+
+ serverEntry, serverCh := waiter.NewChannelEntry(nil)
+ serverWQ.EventRegister(&serverEntry, waiter.WritableEvents)
+ defer serverWQ.EventUnregister(&serverEntry)
+
+ var total int64
+ for {
+ var b [64 << 10]byte
+ var r bytes.Reader
+ r.Reset(b[:])
+ switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
+ case nil:
+ t.Logf("wrote %d bytes", n)
+ total += n
+ continue
+ case *tcpip.ErrWouldBlock:
+ select {
+ case <-serverCh:
+ continue
+ case <-time.After(100 * time.Millisecond):
+ // Well and truly full.
+ t.Logf("send and receive queues are full")
+ }
+ default:
+ t.Fatalf("Write failed: %s", err)
+ }
+ break
+ }
+ t.Logf("wrote %d bytes in total", total)
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+
+ var b [64 << 10]byte
+ var r bytes.Reader
+ r.Reset(b[:])
+ if err := func() error {
+ var total int64
+ defer t.Logf("wrote %d bytes in total", total)
+ for r.Len() != 0 {
+ switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
+ case nil:
+ t.Logf("wrote %d bytes", n)
+ total += n
+ case *tcpip.ErrWouldBlock:
+ for {
+ t.Logf("waiting on server")
+ select {
+ case <-serverCh:
+ case <-time.After(time.Second):
+ if readiness := server.Readiness(waiter.WritableEvents); readiness != 0 {
+ t.Logf("server.Readiness(%b) = %b but channel not signaled", waiter.WritableEvents, readiness)
+ }
+ continue
+ }
+ break
+ }
+ default:
+ return fmt.Errorf("server.Write failed: %s", err)
+ }
+ }
+ if err := server.Shutdown(tcpip.ShutdownWrite); err != nil {
+ return fmt.Errorf("server.Shutdown failed: %s", err)
+ }
+ t.Logf("server end shutdown done")
+ return nil
+ }(); err != nil {
+ t.Error(err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+
+ if err := func() error {
+ total := 0
+ defer t.Logf("read %d bytes in total", total)
+ for {
+ switch res, err := client.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
+ case nil:
+ t.Logf("read %d bytes", res.Count)
+ total += res.Count
+ t.Logf("read total %d bytes till now", total)
+ case *tcpip.ErrClosedForReceive:
+ return nil
+ case *tcpip.ErrWouldBlock:
+ for {
+ t.Logf("waiting on client")
+ select {
+ case <-clientCh:
+ case <-time.After(time.Second):
+ if readiness := client.Readiness(waiter.ReadableEvents); readiness != 0 {
+ return fmt.Errorf("client.Readiness(%b) = %b but channel not signaled", waiter.ReadableEvents, readiness)
+ }
+ continue
+ }
+ break
+ }
+ default:
+ return fmt.Errorf("client.Write failed: %s", err)
+ }
+ }
+ }(); err != nil {
+ t.Error(err)
+ }
+ }()
+ })
+ }
+}
+
// Test the stack receive window advertisement on receiving segments smaller than
 // segment overhead. It verifies that the right edge of the window does not
 // grow when the endpoint is not being read from.
@@ -2143,7 +2425,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
}
- c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.AcceptWithOptionsNoDelay(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS})
// Bump up the receive buffer size such that, when the receive window grows,
// the scaled window exceeds maxUint16.
@@ -2535,7 +2817,7 @@ func TestScaledWindowAccept(t *testing.T) {
// Do 3-way handshake.
// wndScale expected is 3 as 65535 * 3 * 2 < 65535 * 2^3 but > 65535 *2 *2
- c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.PassiveConnectWithOptions(100, 3 /* wndScale */, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -3532,6 +3814,12 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
}
+ // Wait for the connection to timeout after MaxRetries retransmits.
+ initRTO := time.Second
+ minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -3554,8 +3842,6 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
),
)
}
- // Wait for the connection to timeout after MaxRetries retransmits.
- initRTO := 1 * time.Second
select {
case <-notifyCh:
case <-time.After((2 << numRetries) * initRTO):
@@ -3590,9 +3876,13 @@ func TestMaxRTO(t *testing.T) {
defer c.Cleanup()
rto := 1 * time.Second
- opt := tcpip.TCPMaxRTOOption(rto)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ minRTOOpt := tcpip.TCPMinRTOOption(rto / 2)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
+ maxRTOOpt := tcpip.TCPMaxRTOOption(rto)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &maxRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, maxRTOOpt, maxRTOOpt, err)
}
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
@@ -3618,8 +3908,8 @@ func TestMaxRTO(t *testing.T) {
checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
- if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() {
- t.Errorf("Retransmit interval not capped to MaxRTO.\n")
+ if elapsed := time.Since(start); elapsed.Round(time.Second).Seconds() != rto.Seconds() {
+ t.Errorf("Retransmit interval not capped to MaxRTO(%s). %s", rto, elapsed)
}
}
}
@@ -3670,6 +3960,10 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ minRTOOpt := tcpip.TCPMinRTOOption(time.Second)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
// Disabling PMTU discovery causes all packets sent from this socket to
@@ -4736,13 +5030,17 @@ func makeStack() (*stack.Stack, tcpip.Error) {
}
for _, ct := range []struct {
- number tcpip.NetworkProtocolNumber
- address tcpip.Address
+ number tcpip.NetworkProtocolNumber
+ addrWithPrefix tcpip.AddressWithPrefix
}{
- {ipv4.ProtocolNumber, context.StackAddr},
- {ipv6.ProtocolNumber, context.StackV6Addr},
+ {ipv4.ProtocolNumber, context.StackAddrWithPrefix},
+ {ipv6.ProtocolNumber, context.StackV6AddrWithPrefix},
} {
- if err := s.AddAddress(1, ct.number, ct.address); err != nil {
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ct.number,
+ AddressWithPrefix: ct.addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
return nil, err
}
}
@@ -4946,7 +5244,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err)
}
for i := start; i <= end; i++ {
- if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
+ if err := makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
t.Fatalf("Bind(%d) failed: %s", i, err)
}
}
@@ -6304,7 +6602,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
- c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}, 0 /* delay */)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -6385,7 +6683,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// maximum buffer size defined above.
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
- rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+ rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
// NOTE: The timestamp values in the sent packets are meaningless to the
// peer so we just increment the timestamp value by 1 every batch as we
@@ -6515,7 +6813,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// maximum buffer size used by stack.
c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
- rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+ rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, WS: 4})
tsVal := rawEP.TSVal
rawEP.NextSeqNum--
rawEP.SendPacketWithTS(nil, tsVal)
@@ -7430,6 +7728,11 @@ func TestTCPUserTimeout(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ initRTO := 1 * time.Second
+ minRTOOpt := tcpip.TCPMinRTOOption(initRTO)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
+ }
c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -7440,7 +7743,6 @@ func TestTCPUserTimeout(t *testing.T) {
// Ensure that on the next retransmit timer fire, the user timeout has
// expired.
- initRTO := 1 * time.Second
userTimeout := initRTO / 2
v := tcpip.TCPUserTimeoutOption(userTimeout)
if err := c.EP.SetSockOpt(&v); err != nil {
@@ -7954,6 +8256,151 @@ func TestSetStackTimeWaitReuse(t *testing.T) {
}
}
+func TestHandshakeRTT(t *testing.T) {
+ type testCase struct {
+ connect bool
+ tsEnabled bool
+ useCookie bool
+ retrans bool
+ delay time.Duration
+ wantRTT time.Duration
+ }
+ var testCases []testCase
+ for _, connect := range []bool{false, true} {
+ for _, tsEnabled := range []bool{false, true} {
+ for _, useCookie := range []bool{false, true} {
+ for _, retrans := range []bool{false, true} {
+ if connect && useCookie {
+ continue
+ }
+ delay := 800 * time.Millisecond
+ if retrans {
+ delay = 1200 * time.Millisecond
+ }
+ wantRTT := delay
+					// With syncookies, RTT is sampled only when the TS option is enabled.
+ if !retrans && useCookie && !tsEnabled {
+ wantRTT = 0
+ }
+					// If the SYN is retransmitted, RTT is sampled only when the TS option is enabled.
+ if retrans && !tsEnabled {
+ wantRTT = 0
+ }
+ testCases = append(testCases, testCase{connect, tsEnabled, useCookie, retrans, delay, wantRTT})
+ }
+ }
+ }
+ }
+ for _, tt := range testCases {
+ tt := tt
+		t.Run(fmt.Sprintf("connect=%t,TS=%t,cookie=%t,retrans=%t", tt.connect, tt.tsEnabled, tt.useCookie, tt.retrans), func(t *testing.T) {
+ t.Parallel()
+ c := context.New(t, defaultMTU)
+ if tt.useCookie {
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
+ }
+ synOpts := header.TCPSynOptions{}
+ if tt.tsEnabled {
+ synOpts.TS = true
+ synOpts.TSVal = 42
+ }
+ if tt.connect {
+ c.CreateConnectedWithOptions(synOpts, tt.delay)
+ } else {
+ synOpts.MSS = defaultIPv4MSS
+ synOpts.WS = -1
+ c.AcceptWithOptions(-1, synOpts, tt.delay)
+ }
+ var info tcpip.TCPInfoOption
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+ }
+ if got := info.RTT.Round(tt.wantRTT); got != tt.wantRTT {
+ t.Fatalf("got info.RTT=%s, expect %s", got, tt.wantRTT)
+ }
+ if info.RTTVar != 0 && tt.wantRTT == 0 {
+ t.Fatalf("got info.RTTVar=%s, expect 0", info.RTTVar)
+ }
+ if info.RTTVar == 0 && tt.wantRTT != 0 {
+ t.Fatalf("got info.RTTVar=0, expect non zero")
+ }
+ })
+ }
+}
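The RTT-sampling rules that the comments in TestHandshakeRTT encode reduce to a single predicate. A minimal sketch (the helper name is illustrative, not part of this change):

    // RTT is sampled from a handshake that used SYN cookies or a
    // retransmission only when the TS option is enabled.
    func wantSampledRTT(tsEnabled, useCookie, retrans bool, delay time.Duration) time.Duration {
        if (useCookie || retrans) && !tsEnabled {
            return 0
        }
        return delay
    }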
+
+func TestSetRTO(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ minRTO, maxRTO := tcpRTOMinMax(t, c)
+ for _, tt := range []struct {
+ name string
+ RTO time.Duration
+ minRTO time.Duration
+ maxRTO time.Duration
+ err tcpip.Error
+ }{
+ {
+ name: "invalid minRTO",
+ minRTO: maxRTO + time.Second,
+ err: &tcpip.ErrInvalidOptionValue{},
+ },
+ {
+ name: "invalid maxRTO",
+ maxRTO: minRTO - time.Millisecond,
+ err: &tcpip.ErrInvalidOptionValue{},
+ },
+ {
+ name: "valid minRTO",
+ minRTO: maxRTO - time.Second,
+ },
+ {
+ name: "valid maxRTO",
+ maxRTO: minRTO + time.Millisecond,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ var opt tcpip.SettableTransportProtocolOption
+ if tt.minRTO > 0 {
+ min := tcpip.TCPMinRTOOption(tt.minRTO)
+ opt = &min
+ }
+ if tt.maxRTO > 0 {
+ max := tcpip.TCPMaxRTOOption(tt.maxRTO)
+ opt = &max
+ }
+ err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt)
+ if got, want := err, tt.err; got != want {
+ t.Fatalf("c.Stack().SetTransportProtocolOption(TCP, &%T(%v)) = %v, want = %v", opt, opt, got, want)
+ }
+ if tt.err == nil {
+ minRTO, maxRTO := tcpRTOMinMax(t, c)
+ if tt.minRTO > 0 && tt.minRTO != minRTO {
+ t.Fatalf("got minRTO = %s, want %s", minRTO, tt.minRTO)
+ }
+ if tt.maxRTO > 0 && tt.maxRTO != maxRTO {
+ t.Fatalf("got maxRTO = %s, want %s", maxRTO, tt.maxRTO)
+ }
+ }
+ })
+ }
+}
+
+func tcpRTOMinMax(t *testing.T, c *context.Context) (time.Duration, time.Duration) {
+ t.Helper()
+ var minOpt tcpip.TCPMinRTOOption
+ var maxOpt tcpip.TCPMaxRTOOption
+ if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &minOpt); err != nil {
+ t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", minOpt, err)
+ }
+ if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &maxOpt); err != nil {
+ t.Fatalf("c.Stack().TransportProtocolOption(TCP, %T): %s", maxOpt, err)
+ }
+ return time.Duration(minOpt), time.Duration(maxOpt)
+}
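For context, a minimal sketch of pinning the RTO bounds with the options exercised above, assuming s is an initialized *stack.Stack with TCP registered; the cross-validation behavior is inferred from the error cases in TestSetRTO:

    minRTO := tcpip.TCPMinRTOOption(200 * time.Millisecond)
    if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTO); err != nil {
        // Fails with ErrInvalidOptionValue if it exceeds the current max RTO.
    }
    maxRTO := tcpip.TCPMaxRTOOption(30 * time.Second)
    if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &maxRTO); err != nil {
        // Fails with ErrInvalidOptionValue if it falls below the current min RTO.
    }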
+
// generateRandomPayload generates a random byte slice of the specified length
// causing a fatal test failure if it is unable to do so.
func generateRandomPayload(t *testing.T, n int) []byte {
@@ -7964,3 +8411,192 @@ func generateRandomPayload(t *testing.T, n int) []byte {
}
return buf
}
+
+func TestSendBufferTuning(t *testing.T) {
+ const maxPayload = 536
+ const mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload
+ const packetOverheadFactor = 2
+
+ testCases := []struct {
+ name string
+ autoTuningDisabled bool
+ }{
+ {"autoTuningDisabled", true},
+ {"autoTuningEnabled", false},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ // Set the stack option for send buffer size.
+ const defaultSndBufSz = maxPayload * tcp.InitialCwnd
+ const maxSndBufSz = defaultSndBufSz * 10
+ {
+ opt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: defaultSndBufSz, Max: maxSndBufSz}
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
+ }
+ }
+
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000, -1 /* epRcvBuf */)
+
+ oldSz := c.EP.SocketOptions().GetSendBufferSize()
+ if oldSz != defaultSndBufSz {
+ t.Fatalf("Wrong send buffer size got %d want %d", oldSz, defaultSndBufSz)
+ }
+
+ if tc.autoTuningDisabled {
+ c.EP.SocketOptions().SetSendBufferSize(defaultSndBufSz, true /* notify */)
+ }
+
+ data := make([]byte, maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ w, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&w, waiter.WritableEvents)
+ defer c.WQ.EventUnregister(&w)
+
+ bytesRead := 0
+ for {
+ // Packets will be sent until the send buffer
+ // size is reached.
+ var r bytes.Reader
+ r.Reset(data[bytesRead : bytesRead+maxPayload])
+ _, err := c.EP.Write(&r, tcpip.WriteOptions{})
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
+ break
+ }
+
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, 0)
+ bytesRead += maxPayload
+ data = append(data, data...)
+ }
+
+ // Send an ACK and wait for connection to become writable again.
+ c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
+ select {
+ case <-ch:
+ if err := c.EP.LastError(); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for connection")
+ }
+
+ outSz := int64(defaultSndBufSz)
+ if !tc.autoTuningDisabled {
+ // Calculate the new auto tuned send buffer.
+ var info tcpip.TCPInfoOption
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("GetSockOpt failed: %v", err)
+ }
+ outSz = int64(info.SndCwnd) * packetOverheadFactor * maxPayload
+ }
+
+ if newSz := c.EP.SocketOptions().GetSendBufferSize(); newSz != outSz {
+ t.Fatalf("Wrong send buffer size, got %d want %d", newSz, outSz)
+ }
+ })
+ }
+}
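The expected size in the auto-tuning branch is pure arithmetic. A worked example with this test's constants (maxPayload = 536, packetOverheadFactor = 2; the SndCwnd value is illustrative):

    // outSz = SndCwnd * packetOverheadFactor * maxPayload
    // e.g. with info.SndCwnd == 20: 20 * 2 * 536 = 21440 bytes.
    outSz := int64(info.SndCwnd) * packetOverheadFactor * maxPayload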
+
+func TestTimestampSynCookies(t *testing.T) {
+ clock := faketime.NewManualClock()
+ tsNow := func() uint32 {
+ return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Milliseconds())
+ }
+ // Advance the clock so that NowMonotonic is non-zero.
+ clock.Advance(time.Second)
+ c := context.NewWithOpts(t, context.Options{
+ EnableV4: true,
+ EnableV6: true,
+ MTU: defaultMTU,
+ Clock: clock,
+ })
+ defer c.Cleanup()
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
+ }
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+
+ tcpOpts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+ header.EncodeTSOption(42, 0, tcpOpts[2:])
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+ iss := seqnum.Value(context.TestInitialSequenceNumber)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ RcvWnd: seqnum.Size(512),
+ SeqNum: iss,
+ TCPOpts: tcpOpts[:],
+ })
+ // Get the TSVal of the SYN-ACK.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+ initialTSVal := tcpHdr.ParsedOptions().TSVal
+ // Derive the tsOffset.
+ tsOffset := initialTSVal - tsNow()
+
+ header.EncodeTSOption(420, initialTSVal, tcpOpts[2:])
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ RcvWnd: seqnum.Size(512),
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ TCPOpts: tcpOpts[:],
+ })
+ c.EP, _, err = ep.Accept(nil)
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ defer wq.EventUnregister(&we)
+ if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept(nil)
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ } else if err != nil {
+ t.Fatalf("failed to accept: %s", err)
+ }
+
+ // Advance the clock again so that we expect the next TSVal to change.
+ clock.Advance(time.Second)
+ data := []byte{1, 2, 3}
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // The endpoint should have the correct TSOffset so that the received TSVal
+ // matches our expectation.
+ if got, want := header.TCP(header.IPv4(c.GetPacket()).Payload()).ParsedOptions().TSVal, tsNow()+tsOffset; got != want {
+ t.Fatalf("got TSVal = %d, want %d", got, want)
+ }
+}
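The timestamp arithmetic in this test deserves spelling out; TSVal values are uint32 milliseconds, so the subtraction wraps harmlessly. The invariant being checked:

    // At SYN-ACK time the stack commits to a fixed offset:
    //   initialTSVal = tsNow() + tsOffset  =>  tsOffset = initialTSVal - tsNow()
    // Every subsequent segment must then carry:
    //   TSVal = tsNow() + tsOffset
    // which is what the final assertion verifies after advancing the clock.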
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 1deb1fe4d..65925daa5 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -32,7 +32,7 @@ import (
// createConnectedWithTimestampOption creates and connects c.ep with the
// timestamp option enabled.
func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint {
- return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1})
+ return c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true, TSVal: 1})
}
// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on
@@ -131,7 +131,7 @@ func TestTimeStampDisabledConnect(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithOptions(header.TCPSynOptions{})
+ c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{})
}
func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
@@ -147,7 +147,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
tsVal := rand.Uint32()
- c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
+ c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
// Now send some data and validate that timestamp is echoed correctly in the ACK.
data := []byte{1, 2, 3}
@@ -209,7 +209,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
}
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
- c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ c.AcceptWithOptionsNoDelay(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Now send some data with the accepted connection endpoint and validate
// that no timestamp option is sent in the TCP segment.
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 96e4849d2..88bb99354 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -122,6 +122,9 @@ type Options struct {
// MTU indicates the maximum transmission unit on the link layer.
MTU uint32
+
+ // Clock that is used by Stack.
+ Clock tcpip.Clock
}
// Context provides an initialized Network stack and a link layer endpoint
@@ -182,6 +185,7 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
stackOpts := stack.Options{
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ Clock: opts.Clock,
}
if opts.EnableV4 {
stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol)
@@ -239,8 +243,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: StackAddrWithPrefix,
}
- if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v4ProtocolAddr, err)
}
routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv4EmptySubnet,
@@ -253,8 +257,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: StackV6AddrWithPrefix,
}
- if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v6ProtocolAddr, err)
}
routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv6EmptySubnet,
@@ -879,13 +883,21 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
)
}
+// CreateConnectedWithOptionsNoDelay calls CreateConnectedWithOptions with
+// zero delay.
+func (c *Context) CreateConnectedWithOptionsNoDelay(wantOptions header.TCPSynOptions) *RawEndpoint {
+ return c.CreateConnectedWithOptions(wantOptions, 0 /* delay */)
+}
+
// CreateConnectedWithOptions creates and connects c.ep with the specified TCP
// options enabled and returns a RawEndpoint which represents the other end of
-// the connection.
+// the connection. It delays before the SYN-ACK is sent, which gives c.EP a
+// higher RTT estimate so that spurious TLPs aren't sent in tests, helping
+// reduce flakiness.
//
// It also verifies, where required (e.g. Timestamp), that the ACK to the SYN-ACK
// does not carry an option that was not requested.
-func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
var err tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
@@ -911,18 +923,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// TS value.
mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
- checker.IPv4(c.t, b,
- checker.TCP(
- checker.DstPort(TestPort),
- checker.TCPFlags(header.TCPFlagSyn),
- checker.TCPSynOptions(header.TCPSynOptions{
- MSS: mss,
- TS: true,
- WS: int(c.WindowScale),
- SACKPermitted: c.SACKEnabled(),
- }),
- ),
+ synChecker := checker.TCP(
+ checker.DstPort(TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{
+ MSS: mss,
+ TS: true,
+ WS: int(c.WindowScale),
+ SACKPermitted: c.SACKEnabled(),
+ }),
)
+ checker.IPv4(c.t, b, synChecker)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
}
@@ -948,6 +959,10 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Build SYN-ACK.
c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
iss := seqnum.Value(TestInitialSequenceNumber)
+ if delay > 0 {
+ // Sleep so that RTT is increased.
+ time.Sleep(delay)
+ }
c.SendPacket(nil, &Headers{
SrcPort: tcpSeg.DestinationPort(),
DstPort: tcpSeg.SourcePort(),
@@ -959,7 +974,17 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
})
// Read ACK.
- ackPacket := c.GetPacket()
+ var ackPacket []byte
+ // Ignore retransmitted SYN packets.
+ for {
+ packet := c.GetPacket()
+ if header.TCP(header.IPv4(packet).Payload()).Flags()&header.TCPFlagSyn != 0 {
+ checker.IPv4(c.t, packet, synChecker)
+ } else {
+ ackPacket = packet
+ break
+ }
+ }
// Verify TCP header fields.
tcpCheckers := []checker.TransportChecker{
@@ -1016,13 +1041,19 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
}
}
-// AcceptWithOptions initializes a listening endpoint and connects to it with the
-// provided options enabled. It also verifies that the SYN-ACK has the expected
-// values for the provided options.
+// AcceptWithOptionsNoDelay calls AcceptWithOptions with zero delay.
+func (c *Context) AcceptWithOptionsNoDelay(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+ return c.AcceptWithOptions(wndScale, synOptions, 0 /* delay */)
+}
+
+// AcceptWithOptions initializes a listening endpoint and connects to it with
+// the provided options enabled. It delays before sending the final ACK of the
+// three-way handshake. It also verifies that the SYN-ACK has the expected
+// values for the provided options.
//
// The function returns a RawEndpoint representing the other end of the accepted
// endpoint.
-func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
// Create EP and start listening.
wq := &waiter.Queue{}
ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
@@ -1045,7 +1076,7 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
- rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
+ rep := c.PassiveConnectWithOptions(100, wndScale, synOptions, delay)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -1077,13 +1108,14 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
// PassiveConnectWithOptions.
func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
synOptions.WS = -1
- c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions)
+ c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions, 0 /* delay */)
}
// PassiveConnectWithOptions initiates a new connection (with the specified TCP
// options enabled) to the port on which the Context.ep is listening for new
// connections. It also validates that the SYN-ACK has the expected values for
-// the enabled options.
+// the enabled options. The final ACK of the handshake is delayed by the
+// specified duration.
//
// NOTE: MSS is not a negotiated option and it can be asymmetric
// in each direction. This function uses the maxPayload to set the MSS to be
@@ -1093,7 +1125,7 @@ func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCP
// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the
// value of the window scaling option to be sent in the SYN. If synOptions.WS >
// 0 then we send the WindowScale option.
-func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions, delay time.Duration) *RawEndpoint {
c.t.Helper()
opts := make([]byte, header.TCPOptionsMaximumSize)
offset := 0
@@ -1180,7 +1212,10 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
ackHeaders.TCPOpts = opts[:]
}
- // Send ACK.
+ // Delay if requested, then send the ACK.
+ if delay > 0 {
+ time.Sleep(delay)
+ }
c.SendPacket(nil, ackHeaders)
c.RcvdWindowScale = uint8(rcvdSynOptions.WS)
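Taken together, the changes in this file thread a delay parameter through the connect/accept helpers, and existing call sites migrate to the *NoDelay wrappers. A sketch of the migration (option values illustrative; a test uses one form or the other per Context):

    // Before this change:
    //   rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true})
    // After it, keeping the old zero-delay behavior:
    rawEP := c.CreateConnectedWithOptionsNoDelay(header.TCPSynOptions{TS: true})
    _ = rawEP
    // Or, injecting a handshake delay so c.EP computes a higher RTT estimate:
    //   rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true}, 800*time.Millisecond)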
diff --git a/pkg/tcpip/transport/transport.go b/pkg/tcpip/transport/transport.go
new file mode 100644
index 000000000..4c2ae87f4
--- /dev/null
+++ b/pkg/tcpip/transport/transport.go
@@ -0,0 +1,16 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package transport supports transport protocols.
+package transport
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index cdc344ab7..d2c0963b0 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -35,6 +35,8 @@ go_library(
"//pkg/tcpip/header/parse",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/raw",
"//pkg/waiter",
],
@@ -61,5 +63,6 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
+ "@org_golang_x_time//rate:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 82a3f2287..39b1e08c0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,8 +15,8 @@
package udp
import (
+ "fmt"
"io"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sync"
@@ -25,12 +25,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
// +stateify savable
type udpPacket struct {
udpPacketEntry
+ netProto tcpip.NetworkProtocolNumber
senderAddress tcpip.FullAddress
destinationAddress tcpip.FullAddress
packetInfo tcpip.IPPacketInfo
@@ -40,36 +43,6 @@ type udpPacket struct {
tos uint8
}
-// EndpointState represents the state of a UDP endpoint.
-type EndpointState tcpip.EndpointState
-
-// Endpoint states. Note that are represented in a netstack-specific manner and
-// may not be meaningful externally. Specifically, they need to be translated to
-// Linux's representation for these states if presented to userspace.
-const (
- _ EndpointState = iota
- StateInitial
- StateBound
- StateConnected
- StateClosed
-)
-
-// String implements fmt.Stringer.
-func (s EndpointState) String() string {
- switch s {
- case StateInitial:
- return "INITIAL"
- case StateBound:
- return "BOUND"
- case StateConnected:
- return "CONNECTING"
- case StateClosed:
- return "CLOSED"
- default:
- return "UNKNOWN"
- }
-}
-
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -79,7 +52,6 @@ func (s EndpointState) String() string {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and do not
@@ -87,6 +59,9 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
uniqueID uint64
+ net network.Endpoint
+ stats tcpip.TransportEndpointStats
+ ops tcpip.SocketOptions
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -96,37 +71,19 @@ type endpoint struct {
rcvBufSize int
rcvClosed bool
- // The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- // state must be read/set using the EndpointState()/setEndpointState()
- // methods.
- state uint32
- route *stack.Route `state:"manual"`
- dstPort uint16
- ttl uint8
- multicastTTL uint8
- multicastAddr tcpip.Address
- multicastNICID tcpip.NICID
- portFlags ports.Flags
-
lastErrorMu sync.Mutex `state:"nosave"`
lastError tcpip.Error
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ portFlags ports.Flags
+
// Values used to reserve a port or register a transport endpoint.
// (whichever happens first).
boundBindToDevice tcpip.NICID
boundPortFlags ports.Flags
- // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
- // applied while sending packets. Defaults to 0 as on Linux.
- sendTOS uint8
-
- // shutdownFlags represent the current shutdown state of the endpoint.
- shutdownFlags tcpip.ShutdownFlags
-
- // multicastMemberships that need to be remvoed when the endpoint is
- // closed. Protected by the mu mutex.
- multicastMemberships map[multicastMembership]struct{}
+ readShutdown bool
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -136,55 +93,25 @@ type endpoint struct {
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber
- // TODO(b/142022063): Add ability to save and restore per endpoint stats.
- stats tcpip.TransportEndpointStats `state:"nosave"`
-
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
-}
-// +stateify savable
-type multicastMembership struct {
- nicID tcpip.NICID
- multicastAddr tcpip.Address
+ localPort uint16
+ remotePort uint16
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: header.UDPProtocolNumber,
- },
+ stack: s,
waiterQueue: waiterQueue,
- // RFC 1075 section 5.4 recommends a TTL of 1 for membership
- // requests.
- //
- // RFC 5135 4.2.1 appears to assume that IGMP messages have a
- // TTL of 1.
- //
- // RFC 5135 Appendix A defines TTL=1: A multicast source that
- // wants its traffic to not traverse a router (e.g., leave a
- // home network) may find it useful to send traffic with IP
- // TTL=1.
- //
- // Linux defaults to TTL=1.
- multicastTTL: 1,
- multicastMemberships: make(map[multicastMembership]struct{}),
- state: uint32(StateInitial),
- uniqueID: s.UniqueID(),
+ uniqueID: s.UniqueID(),
}
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
e.ops.SetMulticastLoop(true)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ e.net.Init(s, netProto, header.UDPProtocolNumber, &e.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -200,20 +127,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
return e
}
-// setEndpointState updates the state of the endpoint to state atomically. This
-// method is unexported as the only place we should update the state is in this
-// package but we allow the state to be read freely without holding e.mu.
-//
-// Precondition: e.mu must be held to call this method.
-func (e *endpoint) setEndpointState(state EndpointState) {
- atomic.StoreUint32(&e.state, uint32(state))
-}
-
-// EndpointState() returns the current state of the endpoint.
-func (e *endpoint) EndpointState() EndpointState {
- return EndpointState(atomic.LoadUint32(&e.state))
-}
-
// UniqueID implements stack.TransportEndpoint.
func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
@@ -244,16 +157,22 @@ func (e *endpoint) Abort() {
// associated with it.
func (e *endpoint) Close() {
e.mu.Lock()
- e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.EndpointState() {
- case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateClosed:
+ e.mu.Unlock()
+ return
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ id := e.net.Info().ID
+ id.LocalPort = e.localPort
+ id.RemotePort = e.remotePort
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, id, e, e.boundPortFlags, e.boundBindToDevice)
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: tcpip.FullAddress{},
@@ -261,13 +180,10 @@ func (e *endpoint) Close() {
e.stack.ReleasePort(portRes)
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- for mem := range e.multicastMemberships {
- e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
- }
- e.multicastMemberships = make(map[multicastMembership]struct{})
-
// Close the receive list and drain it.
e.rcvMu.Lock()
e.rcvClosed = true
@@ -278,14 +194,9 @@ func (e *endpoint) Close() {
}
e.rcvMu.Unlock()
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
-
- // Update the state.
- e.setEndpointState(StateClosed)
-
+ e.net.Shutdown()
+ e.net.Close()
+ e.readShutdown = true
e.mu.Unlock()
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
@@ -322,21 +233,38 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.receivedAt.UnixNano(),
- }
- if e.ops.GetReceiveTOS() {
- cm.HasTOS = true
- cm.TOS = p.tos
- }
- if e.ops.GetReceiveTClass() {
- cm.HasTClass = true
- // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
- cm.TClass = uint32(p.tos)
+ Timestamp: p.receivedAt,
}
- if e.ops.GetReceivePacketInfo() {
- cm.HasIPPacketInfo = true
- cm.PacketInfo = p.packetInfo
+
+ switch p.netProto {
+ case header.IPv4ProtocolNumber:
+ if e.ops.GetReceiveTOS() {
+ cm.HasTOS = true
+ cm.TOS = p.tos
+ }
+
+ if e.ops.GetReceivePacketInfo() {
+ cm.HasIPPacketInfo = true
+ cm.PacketInfo = p.packetInfo
+ }
+ case header.IPv6ProtocolNumber:
+ if e.ops.GetReceiveTClass() {
+ cm.HasTClass = true
+ // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
+ cm.TClass = uint32(p.tos)
+ }
+
+ if e.ops.GetIPv6ReceivePacketInfo() {
+ cm.HasIPv6PacketInfo = true
+ cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{
+ NIC: p.packetInfo.NIC,
+ Addr: p.packetInfo.DestinationAddr,
+ }
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto))
}
+
if e.ops.GetReceiveOriginalDstAddress() {
cm.HasOriginalDstAddress = true
cm.OriginalDstAddress = p.destinationAddress
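A minimal sketch of a caller opting in to the per-protocol control messages populated above; the setter names are the usual tcpip.SocketOptions ones and are assumed here, not shown in this diff:

    ep.SocketOptions().SetReceiveTOS(true)    // IPv4: fills cm.HasTOS/cm.TOS
    ep.SocketOptions().SetReceiveTClass(true) // IPv6: fills cm.HasTClass/cm.TClass
    var buf bytes.Buffer
    res, _ := ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
    _ = res.ControlMessages // populated per the switch above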
@@ -359,19 +287,19 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
return res, nil
}
-// prepareForWrite prepares the endpoint for sending data. In particular, it
-// binds it if it's still in the initial state. To do so, it must first
+// prepareForWriteInner prepares the endpoint for sending data. In particular,
+// it binds it if it's still in the initial state. To do so, it must first
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
// +checklocks:e.mu
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
- switch e.EndpointState() {
- case StateInitial:
- case StateConnected:
+func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
+ switch e.net.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
return false, nil
- case StateBound:
+ case transport.DatagramEndpointStateBound:
if to == nil {
return false, &tcpip.ErrDestinationRequired{}
}
@@ -386,7 +314,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.EndpointState() != StateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return true, nil
}
@@ -398,33 +326,6 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
return true, nil
}
-// connectRoute establishes a route to the specified interface or the
-// configured multicast interface if no interface is specified and the
-// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
- localAddr := e.ID.LocalAddress
- if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
- // A packet can only originate from a unicast address (i.e., an interface).
- localAddr = ""
- }
-
- if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
- if nicID == 0 {
- nicID = e.multicastNICID
- }
- if localAddr == "" && nicID == 0 {
- localAddr = e.multicastAddr
- }
- }
-
- // Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
- if err != nil {
- return nil, 0, err
- }
- return r, nicID, nil
-}
-
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
@@ -448,18 +349,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return n, err
}
-func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
+func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- // If we've shutdown with SHUT_WR we are in an invalid state for sending.
- if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return udpPacketInfo{}, &tcpip.ErrClosedForSend{}
- }
-
// Prepare for write.
for {
- retry, err := e.prepareForWrite(opts.To)
+ retry, err := e.prepareForWriteInner(opts.To)
if err != nil {
return udpPacketInfo{}, err
}
@@ -469,49 +365,28 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions
}
}
- route := e.route
- dstPort := e.dstPort
+ dst, connected := e.net.GetRemoteAddress()
+ dst.Port = e.remotePort
if opts.To != nil {
- // Reject destination address if it goes through a different
- // NIC than the endpoint was bound to.
- nicID := opts.To.NIC
- if nicID == 0 {
- nicID = tcpip.NICID(e.ops.GetBindToDevice())
- }
- if e.BindNICID != 0 {
- if nicID != 0 && nicID != e.BindNICID {
- return udpPacketInfo{}, &tcpip.ErrNoRoute{}
- }
-
- nicID = e.BindNICID
- }
-
if opts.To.Port == 0 {
// Port 0 is an invalid port to send to.
return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{}
}
- dst, netProto, err := e.checkV4MappedLocked(*opts.To)
- if err != nil {
- return udpPacketInfo{}, err
- }
-
- r, _, err := e.connectRoute(nicID, dst, netProto)
- if err != nil {
- return udpPacketInfo{}, err
- }
- defer r.Release()
-
- route = r
- dstPort = dst.Port
+ dst = *opts.To
+ } else if !connected {
+ return udpPacketInfo{}, &tcpip.ErrDestinationRequired{}
}
- if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
- return udpPacketInfo{}, &tcpip.ErrBroadcastDisabled{}
+ ctx, err := e.net.AcquireContextForWrite(opts)
+ if err != nil {
+ return udpPacketInfo{}, err
}
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
v := make([]byte, p.Len())
if _, err := io.ReadFull(p, v); err != nil {
+ ctx.Release()
return udpPacketInfo{}, &tcpip.ErrBadBuffer{}
}
if len(v) > header.UDPMaximumPacketSize {
@@ -520,50 +395,25 @@ func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions
if so.GetRecvError() {
so.QueueLocalErr(
&tcpip.ErrMessageTooLong{},
- route.NetProto(),
+ e.net.NetProto(),
header.UDPMaximumPacketSize,
- tcpip.FullAddress{
- NIC: route.NICID(),
- Addr: route.RemoteAddress(),
- Port: dstPort,
- },
+ dst,
v,
)
}
+ ctx.Release()
return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
}
- ttl := e.ttl
- useDefaultTTL := ttl == 0
- if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) {
- ttl = e.multicastTTL
- // Multicast allows a 0 TTL.
- useDefaultTTL = false
- }
-
return udpPacketInfo{
- route: route,
- data: buffer.View(v),
- localPort: e.ID.LocalPort,
- remotePort: dstPort,
- ttl: ttl,
- useDefaultTTL: useDefaultTTL,
- tos: e.sendTOS,
- owner: e.owner,
- noChecksum: e.SocketOptions().GetNoChecksum(),
+ ctx: ctx,
+ data: v,
+ localPort: e.localPort,
+ remotePort: dst.Port,
}, nil
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- if err := e.LastError(); err != nil {
- return 0, err
- }
-
- // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
-
// Do not hold lock when sending as loopback is synchronous and if the UDP
// datagram ends up generating an ICMP response then it can result in a
// deadlock where the ICMP response handling ends up acquiring this endpoint's
@@ -574,15 +424,53 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
//
// See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
// locking is prohibited.
- u, err := e.buildUDPPacketInfo(p, opts)
- if err != nil {
+
+ if err := e.LastError(); err != nil {
return 0, err
}
- n, err := u.send()
+
+ udpInfo, err := e.prepareForWrite(p, opts)
if err != nil {
return 0, err
}
- return int64(n), nil
+ defer udpInfo.ctx.Release()
+
+ pktInfo := udpInfo.ctx.PacketInfo()
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(pktInfo.MaxHeaderLength),
+ Data: udpInfo.data.ToVectorisedView(),
+ })
+
+ // Initialize the UDP header.
+ udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ pkt.TransportProtocolNumber = ProtocolNumber
+
+ length := uint16(pkt.Size())
+ udp.Encode(&header.UDPFields{
+ SrcPort: udpInfo.localPort,
+ DstPort: udpInfo.remotePort,
+ Length: length,
+ })
+
+ // Set the checksum field unless TX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value indicates the
+ // transmitter skipped the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if pktInfo.RequiresTXTransportChecksum &&
+ (!e.ops.GetNoChecksum() || pktInfo.NetProto == header.IPv6ProtocolNumber) {
+ udp.SetChecksum(^udp.CalculateChecksum(header.ChecksumCombine(
+ header.PseudoHeaderChecksum(ProtocolNumber, pktInfo.LocalAddress, pktInfo.RemoteAddress, length),
+ pkt.Data().AsRange().Checksum(),
+ )))
+ }
+ if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ e.stack.Stats().UDP.PacketSendErrors.Increment()
+ return 0, err
+ }
+
+ // Track count of packets sent.
+ e.stack.Stats().UDP.PacketsSent.Increment()
+ return int64(len(udpInfo.data)), nil
}
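The rewritten send path funnels everything through a network.WriteContext; judging by the code it replaces, the context stands in for the old route/TTL/TOS bookkeeping. A condensed sketch of the lifecycle, using only calls that appear in this diff (error handling abbreviated):

    ctx, err := e.net.AcquireContextForWrite(opts) // replaces connectRoute/ttl/tos
    if err != nil {
        return 0, err
    }
    defer ctx.Release() // must be released on every path

    pktInfo := ctx.PacketInfo() // header sizing + checksum requirements
    pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
        ReserveHeaderBytes: header.UDPMinimumSize + int(pktInfo.MaxHeaderLength),
        Data:               data.ToVectorisedView(), // data is a buffer.View of the payload
    })
    // ... push and encode the UDP header, set the checksum as above ...
    err = ctx.WritePacket(pkt, false /* headerIncluded */)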
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.
@@ -601,36 +489,7 @@ func (e *endpoint) OnReusePortSet(v bool) {
// SetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.MTUDiscoverOption:
- // Return not supported if the value is not disabling path
- // MTU discovery.
- if v != tcpip.PMTUDiscoveryDont {
- return &tcpip.ErrNotSupported{}
- }
-
- case tcpip.MulticastTTLOption:
- e.mu.Lock()
- e.multicastTTL = uint8(v)
- e.mu.Unlock()
-
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(v)
- e.mu.Unlock()
-
- case tcpip.IPv4TOSOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
-
- case tcpip.IPv6TrafficClassOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
- }
-
- return nil
+ return e.net.SetSockOptInt(opt, v)
}
var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
@@ -642,145 +501,12 @@ func (e *endpoint) HasNIC(id int32) bool {
// SetSockOpt implements tcpip.Endpoint.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.MulticastInterfaceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
-
- fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- fa, netProto, err := e.checkV4MappedLocked(fa)
- if err != nil {
- return err
- }
- nic := v.NIC
- addr := fa.Addr
-
- if nic == 0 && addr == "" {
- e.multicastAddr = ""
- e.multicastNICID = 0
- break
- }
-
- if nic != 0 {
- if !e.stack.CheckNIC(nic) {
- return &tcpip.ErrBadLocalAddress{}
- }
- } else {
- nic = e.stack.CheckLocalAddress(0, netProto, addr)
- if nic == 0 {
- return &tcpip.ErrBadLocalAddress{}
- }
- }
-
- if e.BindNICID != 0 && e.BindNICID != nic {
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- e.multicastNICID = nic
- e.multicastAddr = addr
-
- case *tcpip.AddMembershipOption:
- if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return &tcpip.ErrInvalidOptionValue{}
- }
-
- nicID := v.NIC
-
- if v.InterfaceAddr.Unspecified() {
- if nicID == 0 {
- if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
- nicID = r.NICID()
- r.Release()
- }
- }
- } else {
- nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
- }
- if nicID == 0 {
- return &tcpip.ErrUnknownDevice{}
- }
-
- memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if _, ok := e.multicastMemberships[memToInsert]; ok {
- return &tcpip.ErrPortInUse{}
- }
-
- if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
-
- e.multicastMemberships[memToInsert] = struct{}{}
-
- case *tcpip.RemoveMembershipOption:
- if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return &tcpip.ErrInvalidOptionValue{}
- }
-
- nicID := v.NIC
- if v.InterfaceAddr.Unspecified() {
- if nicID == 0 {
- if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil {
- nicID = r.NICID()
- r.Release()
- }
- }
- } else {
- nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
- }
- if nicID == 0 {
- return &tcpip.ErrUnknownDevice{}
- }
-
- memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if _, ok := e.multicastMemberships[memToRemove]; !ok {
- return &tcpip.ErrBadLocalAddress{}
- }
-
- if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
- return err
- }
-
- delete(e.multicastMemberships, memToRemove)
-
- case *tcpip.SocketDetachFilterOption:
- return nil
- }
- return nil
+ return e.net.SetSockOpt(opt)
}
// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
- case tcpip.IPv4TOSOption:
- e.mu.RLock()
- v := int(e.sendTOS)
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.IPv6TrafficClassOption:
- e.mu.RLock()
- v := int(e.sendTOS)
- e.mu.RUnlock()
- return v, nil
-
- case tcpip.MTUDiscoverOption:
- // The only supported setting is path MTU discovery disabled.
- return tcpip.PMTUDiscoveryDont, nil
-
- case tcpip.MulticastTTLOption:
- e.mu.Lock()
- v := int(e.multicastTTL)
- e.mu.Unlock()
- return v, nil
-
case tcpip.ReceiveQueueSizeOption:
v := 0
e.rcvMu.Lock()
@@ -791,108 +517,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.TTLOption:
- e.mu.Lock()
- v := int(e.ttl)
- e.mu.Unlock()
- return v, nil
-
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
// GetSockOpt implements tcpip.Endpoint.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.MulticastInterfaceOption:
- e.mu.Lock()
- *o = tcpip.MulticastInterfaceOption{
- NIC: e.multicastNICID,
- InterfaceAddr: e.multicastAddr,
- }
- e.mu.Unlock()
-
- default:
- return &tcpip.ErrUnknownProtocolOption{}
- }
- return nil
+ return e.net.GetSockOpt(opt)
}
-// udpPacketInfo contains all information required to send a UDP packet.
-//
-// This should be used as a value-only type, which exists in order to simplify
-// return value syntax. It should not be exported or extended.
+// udpPacketInfo holds information needed to send a UDP packet.
type udpPacketInfo struct {
- route *stack.Route
- data buffer.View
- localPort uint16
- remotePort uint16
- ttl uint8
- useDefaultTTL bool
- tos uint8
- owner tcpip.PacketOwner
- noChecksum bool
-}
-
-// send sends the given packet.
-func (u *udpPacketInfo) send() (int, tcpip.Error) {
- vv := u.data.ToVectorisedView()
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(u.route.MaxHeaderLength()),
- Data: vv,
- })
- pkt.Owner = u.owner
-
- // Initialize the UDP header.
- udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
- pkt.TransportProtocolNumber = ProtocolNumber
-
- length := uint16(pkt.Size())
- udp.Encode(&header.UDPFields{
- SrcPort: u.localPort,
- DstPort: u.remotePort,
- Length: length,
- })
-
- // Set the checksum field unless TX checksum offload is enabled.
- // On IPv4, UDP checksum is optional, and a zero value indicates the
- // transmitter skipped the checksum generation (RFC768).
- // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
- if u.route.RequiresTXTransportChecksum() &&
- (!u.noChecksum || u.route.NetProto() == header.IPv6ProtocolNumber) {
- xsum := u.route.PseudoHeaderChecksum(ProtocolNumber, length)
- for _, v := range vv.Views() {
- xsum = header.Checksum(v, xsum)
- }
- udp.SetChecksum(^udp.CalculateChecksum(xsum))
- }
-
- if u.useDefaultTTL {
- u.ttl = u.route.DefaultTTL()
- }
- if err := u.route.WritePacket(stack.NetworkHeaderParams{
- Protocol: ProtocolNumber,
- TTL: u.ttl,
- TOS: u.tos,
- }, pkt); err != nil {
- u.route.Stats().UDP.PacketSendErrors.Increment()
- return 0, err
- }
-
- // Track count of packets sent.
- u.route.Stats().UDP.PacketsSent.Increment()
- return len(u.data), nil
-}
-
-// checkV4MappedLocked determines the effective network protocol and converts
-// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
- if err != nil {
- return tcpip.FullAddress{}, 0, err
- }
- return unwrapped, netProto, nil
+ ctx network.WriteContext
+ data buffer.View
+ localPort uint16
+ remotePort uint16
}
// Disconnect implements tcpip.Endpoint.
@@ -900,7 +540,7 @@ func (e *endpoint) Disconnect() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.EndpointState() != StateConnected {
+ if e.net.State() != transport.DatagramEndpointStateConnected {
return nil
}
var (
@@ -913,26 +553,28 @@ func (e *endpoint) Disconnect() tcpip.Error {
boundPortFlags := e.boundPortFlags
// Exclude ephemerally bound endpoints.
- if e.BindNICID != 0 || e.ID.LocalAddress == "" {
+ info := e.net.Info()
+ info.ID.LocalPort = e.localPort
+ info.ID.RemotePort = e.remotePort
+ if e.net.WasBound() {
var err tcpip.Error
id = stack.TransportEndpointID{
- LocalPort: e.ID.LocalPort,
- LocalAddress: e.ID.LocalAddress,
+ LocalPort: info.ID.LocalPort,
+ LocalAddress: info.ID.LocalAddress,
}
id, btd, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
return err
}
- e.setEndpointState(StateBound)
boundPortFlags = e.boundPortFlags
} else {
- if e.ID.LocalPort != 0 {
+ if info.ID.LocalPort != 0 {
// Release the ephemeral port.
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: info.ID.LocalAddress,
+ Port: info.ID.LocalPort,
Flags: boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: tcpip.FullAddress{},
@@ -940,15 +582,14 @@ func (e *endpoint) Disconnect() tcpip.Error {
e.stack.ReleasePort(portRes)
e.boundPortFlags = ports.Flags{}
}
- e.setEndpointState(StateInitial)
}
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
- e.ID = id
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, info.ID, e, boundPortFlags, e.boundBindToDevice)
e.boundBindToDevice = btd
- e.route.Release()
- e.route = nil
- e.dstPort = 0
+ e.localPort = id.LocalPort
+ e.remotePort = id.RemotePort
+
+ e.net.Disconnect()
return nil
}
@@ -958,88 +599,48 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicID := addr.NIC
- var localPort uint16
- switch e.EndpointState() {
- case StateInitial:
- case StateBound, StateConnected:
- localPort = e.ID.LocalPort
- if e.BindNICID == 0 {
- break
- }
-
- if nicID != 0 && nicID != e.BindNICID {
- return &tcpip.ErrInvalidEndpointState{}
+ err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
+ nextID.LocalPort = e.localPort
+ nextID.RemotePort = addr.Port
+
+ // Even if we're connected, this endpoint can still be used to send
+ // packets on a different network protocol, so we register both even if
+ // v6only is set to false and this is an ipv6 endpoint.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv4ProtocolNumber,
+ header.IPv6ProtocolNumber,
+ }
}
- nicID = e.BindNICID
- default:
- return &tcpip.ErrInvalidEndpointState{}
- }
+ oldPortFlags := e.boundPortFlags
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- r, nicID, err := e.connectRoute(nicID, addr, netProto)
- if err != nil {
- return err
- }
-
- id := stack.TransportEndpointID{
- LocalAddress: e.ID.LocalAddress,
- LocalPort: localPort,
- RemotePort: addr.Port,
- RemoteAddress: r.RemoteAddress(),
- }
-
- if e.EndpointState() == StateInitial {
- id.LocalAddress = r.LocalAddress()
- }
-
- // Even if we're connected, this endpoint can still be used to send
- // packets on a different network protocol, so we register both even if
- // v6only is set to false and this is an ipv6 endpoint.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
- netProtos = []tcpip.NetworkProtocolNumber{
- header.IPv4ProtocolNumber,
- header.IPv6ProtocolNumber,
+ nextID, btd, err := e.registerWithStack(netProtos, nextID)
+ if err != nil {
+ return err
}
- }
- oldPortFlags := e.boundPortFlags
+ // Remove the old registration.
+ if e.localPort != 0 {
+ previousID.LocalPort = e.localPort
+ previousID.RemotePort = e.remotePort
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, previousID, e, oldPortFlags, e.boundBindToDevice)
+ }
- id, btd, err := e.registerWithStack(netProtos, id)
+ e.localPort = nextID.LocalPort
+ e.remotePort = nextID.RemotePort
+ e.boundBindToDevice = btd
+ e.effectiveNetProtos = netProtos
+ return nil
+ })
if err != nil {
- r.Release()
return err
}
- // Remove the old registration.
- if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
- }
-
- e.ID = id
- e.boundBindToDevice = btd
- if e.route != nil {
- // If the endpoint was already connected then make sure we release the
- // previous route.
- e.route.Release()
- }
- e.route = r
- e.dstPort = addr.Port
- e.RegisterNICID = nicID
- e.effectiveNetProtos = netProtos
-
- e.setEndpointState(StateConnected)
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
-
return nil
}
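Note the ordering inside the ConnectAndThen callback above: the new ID is registered before the old registration is removed, so a failed Connect leaves the endpoint's existing registration intact. A skeleton of the pattern (body elided):

    err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
        // Register under nextID first; bail out (keeping previousID live) on error.
        // Only after success: unregister previousID and record the new ports.
        return nil
    })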
@@ -1054,15 +655,23 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- // A socket in the bound state can still receive multicast messages,
- // so we need to notify waiters on shutdown.
- if state := e.EndpointState(); state != StateBound && state != StateConnected {
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
- e.shutdownFlags |= flags
+ if flags&tcpip.ShutdownWrite != 0 {
+ if err := e.net.Shutdown(); err != nil {
+ return err
+ }
+ }
if flags&tcpip.ShutdownRead != 0 {
+ e.readShutdown = true
+
e.rcvMu.Lock()
wasClosed := e.rcvClosed
e.rcvClosed = true
@@ -1088,7 +697,7 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- if e.ID.LocalPort == 0 {
+ if e.localPort == 0 {
portRes := ports.Reservation{
Networks: netProtos,
Transport: ProtocolNumber,
@@ -1126,56 +735,43 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id
func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.EndpointState() != StateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Expand netProtos to include v4 and v6 if the caller is binding to a
- // wildcard (empty) address, and this is an IPv6 endpoint with v6only
- // set to false.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" {
- netProtos = []tcpip.NetworkProtocolNumber{
- header.IPv6ProtocolNumber,
- header.IPv4ProtocolNumber,
+ err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{boundNetProto}
+ if boundNetProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && boundAddr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
}
- }
- nicID := addr.NIC
- if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
- // A local unicast address was specified, verify that it's valid.
- nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
- if nicID == 0 {
- return &tcpip.ErrBadLocalAddress{}
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: boundAddr,
+ }
+ id, btd, err := e.registerWithStack(netProtos, id)
+ if err != nil {
+ return err
}
- }
- id := stack.TransportEndpointID{
- LocalPort: addr.Port,
- LocalAddress: addr.Addr,
- }
- id, btd, err := e.registerWithStack(netProtos, id)
+ e.localPort = id.LocalPort
+ e.boundBindToDevice = btd
+ e.effectiveNetProtos = netProtos
+ return nil
+ })
if err != nil {
return err
}
- e.ID = id
- e.boundBindToDevice = btd
- e.RegisterNICID = nicID
- e.effectiveNetProtos = netProtos
-
- // Mark endpoint as bound.
- e.setEndpointState(StateBound)
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
-
return nil
}
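A minimal sketch, from a caller's perspective, of the dual-stack wildcard bind that the netProtos expansion above enables (the port is illustrative):

    var wq waiter.Queue
    ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
    if err != nil {
        panic(err)
    }
    ep.SocketOptions().SetV6Only(false)
    // An empty bind address registers the endpoint for both IPv4 and IPv6.
    if err := ep.Bind(tcpip.FullAddress{Port: 7777}); err != nil {
        panic(err)
    }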
@@ -1190,9 +786,6 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
return err
}
- // Save the effective NICID generated by bindLocked.
- e.BindNICID = e.RegisterNICID
-
return nil
}
@@ -1201,16 +794,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- addr := e.ID.LocalAddress
- if e.EndpointState() == StateConnected {
- addr = e.route.LocalAddress()
- }
-
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: addr,
- Port: e.ID.LocalPort,
- }, nil
+ addr := e.net.GetLocalAddress()
+ addr.Port = e.localPort
+ return addr, nil
}
// GetRemoteAddress returns the address to which the endpoint is connected.
@@ -1218,15 +804,13 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.EndpointState() != StateConnected || e.dstPort == 0 {
+ addr, connected := e.net.GetRemoteAddress()
+ if !connected || e.remotePort == 0 {
return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
- }, nil
+ addr.Port = e.remotePort
+ return addr, nil
}
// Readiness returns the current readiness of the endpoint. For example, if
@@ -1321,6 +905,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
// Push new packet into receive list and increment the buffer size.
packet := &udpPacket{
+ netProto: pkt.NetworkProtocolNumber,
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
@@ -1376,19 +961,20 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p
payload = udp.Payload()
}
+ id := e.net.Info().ID
e.SocketOptions().QueueErr(&tcpip.SockError{
Err: err,
Cause: transErr,
Payload: payload,
Dst: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
+ Addr: id.RemoteAddress,
+ Port: e.remotePort,
},
Offender: tcpip.FullAddress{
NIC: pkt.NICID,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
+ Addr: id.LocalAddress,
+ Port: e.localPort,
},
NetProto: pkt.NetworkProtocolNumber,
})
@@ -1403,7 +989,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
// TODO(gvisor.dev/issues/5270): Handle all transport errors.
switch transErr.Kind() {
case stack.DestinationPortUnreachableTransportError:
- if e.EndpointState() == StateConnected {
+ if e.net.State() == transport.DatagramEndpointStateConnected {
e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt)
}
}
@@ -1411,16 +997,17 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
// State implements tcpip.Endpoint.
func (e *endpoint) State() uint32 {
- return uint32(e.EndpointState())
+ return uint32(e.net.State())
}
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
- return &ret
+ defer e.mu.RUnlock()
+ info := e.net.Info()
+ info.ID.LocalPort = e.localPort
+ info.ID.RemotePort = e.remotePort
+ return &info
}
// Stats returns a pointer to the endpoint stats.
@@ -1431,13 +1018,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements tcpip.Endpoint.
func (*endpoint) Wait() {}
-func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
- return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
-}
-
// SetOwner implements tcpip.Endpoint.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.owner = owner
+ e.net.SetOwner(owner)
}
// SocketOptions implements tcpip.Endpoint.
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 1f638c3f6..2ff8b0482 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -15,12 +15,13 @@
package udp
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
)
// saveReceivedAt is invoked by stateify.
@@ -35,17 +36,11 @@ func (p *udpPacket) loadReceivedAt(nsec int64) {
// saveData saves udpPacket.data field.
func (p *udpPacket) saveData() buffer.VectorisedView {
- // We cannot save p.data directly as p.data.views may alias to p.views,
- // which is not allowed by state framework (in-struct pointer).
return p.data.Clone(nil)
}
// loadData loads udpPacket.data field.
func (p *udpPacket) loadData(data buffer.VectorisedView) {
- // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
- // here because data.views is not guaranteed to be loaded by now. Plus,
- // data.views will be allocated anyway so there really is little point
- // of utilizing p.views for data.views.
p.data = data
}
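
saveData and loadData follow the stateify convention: a struct field tagged state:".(T)" is converted through saveField/loadField hooks at save and restore time. A hypothetical type showing the same hook shape (the widget name and field are illustrative, not part of this change):

// +stateify savable
type widget struct {
	data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
}

// saveData is invoked by stateify on save; cloning detaches any views
// that would otherwise alias memory the state framework cannot encode.
func (w *widget) saveData() buffer.VectorisedView {
	return w.data.Clone(nil)
}

// loadData is invoked by stateify on restore with the saved value.
func (w *widget) loadData(data buffer.VectorisedView) {
	w.data = data
}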
@@ -66,50 +61,28 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.mu.Lock()
defer e.mu.Unlock()
+ e.net.Resume(s)
+
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- for m := range e.multicastMemberships {
- if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
- panic(err)
- }
- }
-
- state := e.EndpointState()
- if state != StateBound && state != StateConnected {
- return
- }
-
- netProto := e.effectiveNetProtos[0]
- // Connect() and bindLocked() both assert
- //
- // netProto == header.IPv6ProtocolNumber
- //
- // before creating a multi-entry effectiveNetProtos.
- if len(e.effectiveNetProtos) > 1 {
- netProto = header.IPv6ProtocolNumber
- }
-
- var err tcpip.Error
- if state == StateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ // Our saved state had ports, but the restored stack holds no
+ // reservation for them, so pass the saved ports back through the
+ // reservation machinery to re-reserve them.
+ var err tcpip.Error
+ id := e.net.Info().ID
+ id.LocalPort = e.localPort
+ id.RemotePort = e.remotePort
+ id, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
panic(err)
}
- } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound
- // A local unicast address is specified, verify that it's valid.
- if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
- // Our saved state had a port, but we don't actually have a
- // reservation. We need to remove the port from our state, but still
- // pass it to the reservation machinery.
- id := e.ID
- e.ID.LocalPort = 0
- e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id)
- if err != nil {
- panic(err)
+ e.localPort = id.LocalPort
+ e.remotePort = id.RemotePort
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
}
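
Resume now keys off the shared transport.DatagramEndpointState instead of UDP-private state constants: only Bound and Connected endpoints carried a port that must be re-reserved, and any unrecognized state panics rather than restoring silently. The same idiom, isolated as a sketch (imports: fmt, pkg/tcpip/transport):

func needsPortReservation(s transport.DatagramEndpointState) bool {
	switch s {
	case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
		return false
	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
		return true
	default:
		// Fail loudly on restore rather than run with a port the
		// stack never reserved.
		panic(fmt.Sprintf("unhandled state = %s", s))
	}
}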
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 7c357cb09..7238fc019 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -70,28 +70,29 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
+ ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
netHdr := r.pkt.Network()
- route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */)
- if err != nil {
+ if err := ep.net.Bind(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.DestinationAddress(), Port: r.id.LocalPort}); err != nil {
+ return nil, err
+ }
+
+ if err := ep.net.Connect(tcpip.FullAddress{NIC: r.pkt.NICID, Addr: netHdr.SourceAddress(), Port: r.id.RemotePort}); err != nil {
return nil, err
}
- ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
- route.Release()
return nil, err
}
- ep.ID = r.id
- ep.route = route
- ep.dstPort = r.id.RemotePort
+ ep.localPort = r.id.LocalPort
+ ep.remotePort = r.id.RemotePort
ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}
- ep.RegisterNICID = r.pkt.NICID
ep.boundPortFlags = ep.portFlags
- ep.state = uint32(StateConnected)
-
ep.rcvMu.Lock()
ep.rcvReady = true
ep.rcvMu.Unlock()
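
CreateEndpoint now builds the connected endpoint by calling Bind and Connect on the network-layer endpoint instead of hand-rolling a route with FindRoute. The forwarder's public surface is unchanged; a usage sketch, assuming an initialized *stack.Stack named s and a hypothetical serve function:

fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
	var wq waiter.Queue
	ep, err := r.CreateEndpoint(&wq)
	if err != nil {
		// Bind, Connect, or the port registration failed; drop the
		// session.
		return
	}
	// ep is bound to the packet's destination and connected to its
	// source, ready for Read/Write.
	go serve(ep)
})
s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)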
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 4008cacf2..b3199489c 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -290,6 +291,7 @@ type testContext struct {
t *testing.T
linkEP *channel.Endpoint
s *stack.Stack
+ nicID tcpip.NICID
ep tcpip.Endpoint
wq waiter.Queue
@@ -301,6 +303,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
}
func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext {
+ const nicID = 1
+
t.Helper()
options := stack.Options{
@@ -310,38 +314,50 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo
Clock: &faketime.NullClock{},
}
s := stack.New(options)
+ // Disable the ICMP rate limiter: these tests use the Null clock, which
+ // never advances time, so the limiter's bucket would never refill and
+ // every ICMP message would be dropped.
+ s.SetICMPLimit(rate.Inf)
ep := channel.New(256, mtu, "")
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %s", err)
+ if err := s.CreateNIC(nicID, wep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %s", err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(stackAddr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
}
- if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %s", err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(stackV6Addr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
}
s.SetRouteTable([]tcpip.Route{
{
Destination: header.IPv4EmptySubnet,
- NIC: 1,
+ NIC: nicID,
},
{
Destination: header.IPv6EmptySubnet,
- NIC: 1,
+ NIC: nicID,
},
})
return &testContext{
t: t,
s: s,
+ nicID: nicID,
linkEP: ep,
}
}
@@ -1353,64 +1369,70 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
func TestReadIPPacketInfo(t *testing.T) {
tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- flow testFlow
- expectedLocalAddr tcpip.Address
- expectedDestAddr tcpip.Address
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ checker func(tcpip.NICID) checker.ControlMessagesChecker
}{
{
- name: "IPv4 unicast",
- proto: header.IPv4ProtocolNumber,
- flow: unicastV4,
- expectedLocalAddr: stackAddr,
- expectedDestAddr: stackAddr,
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ LocalAddr: stackAddr,
+ DestinationAddr: stackAddr,
+ })
+ },
},
{
name: "IPv4 multicast",
proto: header.IPv4ProtocolNumber,
flow: multicastV4,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: multicastAddr,
- expectedDestAddr: multicastAddr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ // TODO(gvisor.dev/issue/3556): Check for a unicast address.
+ LocalAddr: multicastAddr,
+ DestinationAddr: multicastAddr,
+ })
+ },
},
{
name: "IPv4 broadcast",
proto: header.IPv4ProtocolNumber,
flow: broadcast,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: broadcastAddr,
- expectedDestAddr: broadcastAddr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ // TODO(gvisor.dev/issue/3556): Check for a unicast address.
+ LocalAddr: broadcastAddr,
+ DestinationAddr: broadcastAddr,
+ })
+ },
},
{
- name: "IPv6 unicast",
- proto: header.IPv6ProtocolNumber,
- flow: unicastV6,
- expectedLocalAddr: stackV6Addr,
- expectedDestAddr: stackV6Addr,
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
+ NIC: id,
+ Addr: stackV6Addr,
+ })
+ },
},
{
name: "IPv6 multicast",
proto: header.IPv6ProtocolNumber,
flow: multicastV6,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: multicastV6Addr,
- expectedDestAddr: multicastV6Addr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
+ NIC: id,
+ Addr: multicastV6Addr,
+ })
+ },
},
}
@@ -1433,13 +1455,16 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
- c.ep.SocketOptions().SetReceivePacketInfo(true)
+ switch f := test.flow.netProto(); f {
+ case header.IPv4ProtocolNumber:
+ c.ep.SocketOptions().SetReceivePacketInfo(true)
+ case header.IPv6ProtocolNumber:
+ c.ep.SocketOptions().SetIPv6ReceivePacketInfo(true)
+ default:
+ t.Fatalf("unhandled protocol number = %d", f)
+ }
- testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
- NIC: 1,
- LocalAddr: test.expectedLocalAddr,
- DestinationAddr: test.expectedDestAddr,
- }))
+ testRead(c, test.flow, test.checker(c.nicID))
if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
@@ -1644,8 +1669,10 @@ func TestSetTTL(t *testing.T) {
}
}
+var v4PacketFlows = [...]testFlow{unicastV4, multicastV4, broadcast, unicastV4in6, multicastV4in6, broadcastIn6}
+
func TestSetTOS(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ for _, flow := range v4PacketFlows {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1680,8 +1707,10 @@ func TestSetTOS(t *testing.T) {
}
}
+var v6PacketFlows = [...]testFlow{unicastV6, unicastV6Only, multicastV6}
+
func TestSetTClass(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
+ for _, flow := range v6PacketFlows {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1725,8 +1754,14 @@ func TestReceiveTosTClass(t *testing.T) {
name string
tests []testFlow
}{
- {RcvTOSOpt, []testFlow{unicastV4, broadcast}},
- {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ {
+ name: RcvTOSOpt,
+ tests: v4PacketFlows[:],
+ },
+ {
+ name: RcvTClassOpt,
+ tests: v6PacketFlows[:],
+ },
}
for _, testCase := range testCases {
for _, flow := range testCase.tests {
@@ -1737,6 +1772,14 @@ func TestReceiveTosTClass(t *testing.T) {
c.createEndpointForFlow(flow)
name := testCase.name
+ if flow.isMulticast() {
+ netProto := flow.netProto()
+ addr := flow.getMcastAddr()
+ if err := c.s.JoinGroup(netProto, c.nicID, addr); err != nil {
+ c.t.Fatalf("JoinGroup(%d, %d, %s): %s", netProto, c.nicID, addr, err)
+ }
+ }
+
var optionGetter func() bool
var optionSetter func(bool)
switch name {
@@ -2482,8 +2525,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID1, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
}
s.SetRouteTable(test.routes)
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD
index 234125c38..8902be2d3 100644
--- a/pkg/unet/BUILD
+++ b/pkg/unet/BUILD
@@ -10,6 +10,7 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/eventfd",
"//pkg/sync",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/unet/unet.go b/pkg/unet/unet.go
index 40fa72925..0dc0c37bd 100644
--- a/pkg/unet/unet.go
+++ b/pkg/unet/unet.go
@@ -23,6 +23,7 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -55,15 +56,6 @@ func socket(packet bool) (int, error) {
return fd, nil
}
-// eventFD returns a new event FD with initial value 0.
-func eventFD() (int, error) {
- f, _, e := unix.Syscall(unix.SYS_EVENTFD2, 0, 0, 0)
- if e != 0 {
- return -1, e
- }
- return int(f), nil
-}
-
// Socket is a connected unix domain socket.
type Socket struct {
// gate protects use of fd.
@@ -78,7 +70,7 @@ type Socket struct {
// efd is an event FD that is signaled when the socket is closing.
//
// efd is immutable and remains valid until Close/Release.
- efd int
+ efd eventfd.Eventfd
// race is an atomic variable used to avoid triggering the race
// detector. See comment in SocketPair below.
@@ -95,7 +87,7 @@ func NewSocket(fd int) (*Socket, error) {
return nil, err
}
- efd, err := eventFD()
+ efd, err := eventfd.Create()
if err != nil {
return nil, err
}
@@ -110,16 +102,14 @@ func NewSocket(fd int) (*Socket, error) {
// closing the event FD.
func (s *Socket) finish() error {
// Signal any blocked or future polls.
- //
- // N.B. eventfd writes must be 8 bytes.
- if _, err := unix.Write(s.efd, []byte{1, 0, 0, 0, 0, 0, 0, 0}); err != nil {
+ if err := s.efd.Notify(); err != nil {
return err
}
// Close the gate, blocking until all FD users leave.
s.gate.Close()
- return unix.Close(s.efd)
+ return s.efd.Close()
}
// Close closes the socket.
diff --git a/pkg/unet/unet_unsafe.go b/pkg/unet/unet_unsafe.go
index f0bf93ddd..ea281fec3 100644
--- a/pkg/unet/unet_unsafe.go
+++ b/pkg/unet/unet_unsafe.go
@@ -43,7 +43,7 @@ func (s *Socket) wait(write bool) error {
},
{
// The eventfd, signaled when we are closing.
- Fd: int32(s.efd),
+ Fd: int32(s.efd.FD()),
Events: unix.POLLIN,
},
}
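
The new pkg/eventfd package (added in this series; see the Diffstat) wraps the raw SYS_EVENTFD2 syscall and the mandatory 8-byte write behind a small API. Only Create, Notify, FD, and Close appear in this diff; a sketch of the close-notification pattern unet uses (imports: pkg/eventfd, golang.org/x/sys/unix), with dataFD as a hypothetical socket descriptor:

func pollUntilReadableOrClosed(dataFD int, efd eventfd.Eventfd) error {
	pfds := []unix.PollFd{
		{Fd: int32(dataFD), Events: unix.POLLIN},
		// Signaled by the closer via efd.Notify(), waking every poller.
		{Fd: int32(efd.FD()), Events: unix.POLLIN},
	}
	_, err := unix.Poll(pfds, -1)
	return err
}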
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index cde1038ed..f46a00e42 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -429,7 +429,7 @@ type IOSequence struct {
// return 0, nil
// }
// if f.availableBytes == 0 {
-// return 0, syserror.ErrWouldBlock
+// return 0, linuxerr.ErrWouldBlock
// }
// return ioseq.CopyOutFrom(..., reader)
//
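
To round out the updated doc comment, a hedged sketch of the read path it describes (imports: sync, pkg/context, pkg/errors/linuxerr, pkg/safemem, pkg/usermem; context.Context is gVisor's pkg/context); the file type and its fields are hypothetical:

// file is a hypothetical nonblocking data source.
type file struct {
	mu             sync.Mutex
	availableBytes int64
	reader         safemem.Reader
}

func (f *file) Read(ctx context.Context, ioseq usermem.IOSequence) (int64, error) {
	f.mu.Lock()
	defer f.mu.Unlock()
	if ioseq.NumBytes() == 0 {
		return 0, nil
	}
	if f.availableBytes == 0 {
		// Let the caller block or surface EAGAIN, per the comment
		// above.
		return 0, linuxerr.ErrWouldBlock
	}
	return ioseq.CopyOutFrom(ctx, f.reader)
}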