Diffstat (limited to 'pkg')
-rw-r--r--  pkg/abi/linux/BUILD | 3
-rw-r--r--  pkg/abi/linux/fcntl.go | 41
-rw-r--r--  pkg/abi/linux/file.go | 52
-rw-r--r--  pkg/abi/linux/file_amd64.go | 34
-rw-r--r--  pkg/abi/linux/file_arm64.go | 35
-rw-r--r--  pkg/abi/linux/fs.go | 7
-rw-r--r--  pkg/abi/linux/netlink_route.go | 6
-rw-r--r--  pkg/abi/linux/socket.go | 21
-rw-r--r--  pkg/abi/linux/xattr.go | 27
-rw-r--r--  pkg/cpuid/cpuid.go | 48
-rw-r--r--  pkg/eventchannel/BUILD | 1
-rw-r--r--  pkg/fd/BUILD | 3
-rw-r--r--  pkg/fd/fd.go | 8
-rw-r--r--  pkg/flipcall/BUILD | 2
-rw-r--r--  pkg/flipcall/ctrl_futex.go | 2
-rw-r--r--  pkg/flipcall/flipcall.go | 10
-rw-r--r--  pkg/flipcall/flipcall_unsafe.go | 10
-rw-r--r--  pkg/flipcall/futex_linux.go | 6
-rw-r--r--  pkg/p9/client_file.go | 10
-rw-r--r--  pkg/p9/file.go | 2
-rw-r--r--  pkg/p9/handlers.go | 54
-rw-r--r--  pkg/p9/p9.go | 41
-rw-r--r--  pkg/p9/p9test/client_test.go | 78
-rw-r--r--  pkg/p9/server.go | 6
-rw-r--r--  pkg/p9/version.go | 8
-rw-r--r--  pkg/procid/procid_amd64.s | 2
-rw-r--r--  pkg/procid/procid_arm64.s | 2
-rw-r--r--  pkg/seccomp/seccomp.go | 17
-rw-r--r--  pkg/seccomp/seccomp_rules.go | 3
-rw-r--r--  pkg/seccomp/seccomp_test.go | 48
-rw-r--r--  pkg/seccomp/seccomp_test_victim.go | 2
-rw-r--r--  pkg/sentry/arch/arch_amd64.go | 4
-rw-r--r--  pkg/sentry/context/context.go | 63
-rw-r--r--  pkg/sentry/control/pprof.go | 15
-rw-r--r--  pkg/sentry/control/proc.go | 48
-rw-r--r--  pkg/sentry/control/proc_test.go | 10
-rw-r--r--  pkg/sentry/fs/BUILD | 2
-rw-r--r--  pkg/sentry/fs/fdpipe/pipe_opener_test.go | 1
-rw-r--r--  pkg/sentry/fs/flags.go | 7
-rw-r--r--  pkg/sentry/fs/gofer/file_state.go | 8
-rw-r--r--  pkg/sentry/fs/gofer/handles.go | 5
-rw-r--r--  pkg/sentry/fs/gofer/inode.go | 20
-rw-r--r--  pkg/sentry/fs/gofer/path.go | 4
-rw-r--r--  pkg/sentry/fs/gofer/session.go | 6
-rw-r--r--  pkg/sentry/fs/host/BUILD | 2
-rw-r--r--  pkg/sentry/fs/host/util.go | 2
-rw-r--r--  pkg/sentry/fs/host/util_amd64_unsafe.go | 41
-rw-r--r--  pkg/sentry/fs/host/util_arm64_unsafe.go | 41
-rw-r--r--  pkg/sentry/fs/host/util_unsafe.go | 19
-rw-r--r--  pkg/sentry/fs/inode.go | 12
-rw-r--r--  pkg/sentry/fs/inode_overlay.go | 11
-rw-r--r--  pkg/sentry/fs/overlay.go | 4
-rw-r--r--  pkg/sentry/fs/proc/BUILD | 1
-rw-r--r--  pkg/sentry/fs/proc/net.go | 195
-rw-r--r--  pkg/sentry/fs/proc/sys_net.go | 17
-rw-r--r--  pkg/sentry/fs/proc/task.go | 38
-rw-r--r--  pkg/sentry/fs/tty/master.go | 5
-rw-r--r--  pkg/sentry/fs/tty/slave.go | 5
-rw-r--r--  pkg/sentry/fs/tty/terminal.go | 4
-rw-r--r--  pkg/sentry/fsimpl/ext/BUILD | 1
-rw-r--r--  pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go | 8
-rw-r--r--  pkg/sentry/fsimpl/ext/block_map_file.go | 4
-rw-r--r--  pkg/sentry/fsimpl/ext/block_map_test.go | 4
-rw-r--r--  pkg/sentry/fsimpl/ext/dentry.go | 10
-rw-r--r--  pkg/sentry/fsimpl/ext/ext.go | 10
-rw-r--r--  pkg/sentry/fsimpl/ext/ext_test.go | 43
-rw-r--r--  pkg/sentry/fsimpl/ext/extent_file.go | 6
-rw-r--r--  pkg/sentry/fsimpl/ext/extent_test.go | 4
-rw-r--r--  pkg/sentry/fsimpl/ext/file_description.go | 23
-rw-r--r--  pkg/sentry/fsimpl/ext/filesystem.go | 44
-rw-r--r--  pkg/sentry/fsimpl/ext/inode.go | 31
-rw-r--r--  pkg/sentry/fsimpl/kernfs/BUILD | 60
-rw-r--r--  pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go | 117
-rw-r--r--  pkg/sentry/fsimpl/kernfs/fd_impl_util.go | 193
-rw-r--r--  pkg/sentry/fsimpl/kernfs/filesystem.go | 743
-rw-r--r--  pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 492
-rw-r--r--  pkg/sentry/fsimpl/kernfs/kernfs.go | 405
-rw-r--r--  pkg/sentry/fsimpl/kernfs/kernfs_test.go | 423
-rw-r--r--  pkg/sentry/fsimpl/memfs/BUILD | 22
-rw-r--r--  pkg/sentry/fsimpl/memfs/benchmark_test.go | 28
-rw-r--r--  pkg/sentry/fsimpl/memfs/filesystem.go | 129
-rw-r--r--  pkg/sentry/fsimpl/memfs/memfs.go | 32
-rw-r--r--  pkg/sentry/fsimpl/memfs/named_pipe.go | 62
-rw-r--r--  pkg/sentry/fsimpl/memfs/pipe_test.go | 233
-rw-r--r--  pkg/sentry/fsimpl/proc/filesystems.go | 2
-rw-r--r--  pkg/sentry/fsimpl/proc/mounts.go | 2
-rw-r--r--  pkg/sentry/fsimpl/proc/task.go | 4
-rw-r--r--  pkg/sentry/inet/BUILD | 5
-rw-r--r--  pkg/sentry/inet/inet.go | 32
-rw-r--r--  pkg/sentry/inet/test_stack.go | 16
-rw-r--r--  pkg/sentry/kernel/BUILD | 4
-rw-r--r--  pkg/sentry/kernel/auth/BUILD | 2
-rw-r--r--  pkg/sentry/kernel/context.go | 32
-rw-r--r--  pkg/sentry/kernel/fd_table.go | 22
-rw-r--r--  pkg/sentry/kernel/fd_table_test.go | 36
-rw-r--r--  pkg/sentry/kernel/futex/BUILD | 2
-rw-r--r--  pkg/sentry/kernel/kernel.go | 48
-rw-r--r--  pkg/sentry/kernel/pipe/BUILD | 3
-rw-r--r--  pkg/sentry/kernel/pipe/node.go | 72
-rw-r--r--  pkg/sentry/kernel/pipe/pipe.go | 24
-rw-r--r--  pkg/sentry/kernel/pipe/pipe_util.go | 213
-rw-r--r--  pkg/sentry/kernel/pipe/reader_writer.go | 111
-rw-r--r--  pkg/sentry/kernel/pipe/vfs.go | 220
-rw-r--r--  pkg/sentry/kernel/ptrace_arm64.go | 1
-rw-r--r--  pkg/sentry/kernel/semaphore/semaphore.go | 6
-rw-r--r--  pkg/sentry/kernel/signalfd/BUILD | 4
-rw-r--r--  pkg/sentry/kernel/syscalls.go | 8
-rw-r--r--  pkg/sentry/kernel/task.go | 28
-rw-r--r--  pkg/sentry/kernel/task_block.go | 8
-rw-r--r--  pkg/sentry/kernel/task_clone.go | 1
-rw-r--r--  pkg/sentry/kernel/task_context.go | 27
-rw-r--r--  pkg/sentry/kernel/task_exec.go | 3
-rw-r--r--  pkg/sentry/kernel/task_exit.go | 1
-rw-r--r--  pkg/sentry/kernel/task_log.go | 86
-rw-r--r--  pkg/sentry/kernel/task_run.go | 14
-rw-r--r--  pkg/sentry/kernel/task_start.go | 8
-rw-r--r--  pkg/sentry/kernel/task_syscall.go | 8
-rw-r--r--  pkg/sentry/kernel/tty.go | 11
-rw-r--r--  pkg/sentry/loader/elf.go | 21
-rw-r--r--  pkg/sentry/loader/loader.go | 263
-rw-r--r--  pkg/sentry/loader/vdso.go | 2
-rw-r--r--  pkg/sentry/memmap/BUILD | 1
-rw-r--r--  pkg/sentry/memmap/memmap.go | 8
-rw-r--r--  pkg/sentry/mm/BUILD | 2
-rw-r--r--  pkg/sentry/mm/mm.go | 6
-rw-r--r--  pkg/sentry/pgalloc/save_restore.go | 13
-rw-r--r--  pkg/sentry/platform/kvm/BUILD | 13
-rw-r--r--  pkg/sentry/platform/kvm/address_space.go | 2
-rw-r--r--  pkg/sentry/platform/kvm/allocator.go | 2
-rw-r--r--  pkg/sentry/platform/kvm/bluepill.go | 24
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_amd64.go | 20
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go | 7
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_arm64.go | 79
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_arm64.s | 87
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go | 28
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_fault.go | 10
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_unsafe.go | 23
-rw-r--r--  pkg/sentry/platform/kvm/filters_amd64.go (renamed from pkg/sentry/platform/kvm/filters.go) | 0
-rw-r--r--  pkg/sentry/platform/kvm/filters_arm64.go | 32
-rw-r--r--  pkg/sentry/platform/kvm/kvm.go | 5
-rw-r--r--  pkg/sentry/platform/kvm/kvm_amd64.go | 9
-rw-r--r--  pkg/sentry/platform/kvm/kvm_arm64.go | 83
-rw-r--r--  pkg/sentry/platform/kvm/kvm_arm64_unsafe.go | 39
-rw-r--r--  pkg/sentry/platform/kvm/kvm_const.go | 11
-rw-r--r--  pkg/sentry/platform/kvm/kvm_const_arm64.go | 132
-rw-r--r--  pkg/sentry/platform/kvm/machine.go | 26
-rw-r--r--  pkg/sentry/platform/kvm/machine_amd64.go | 10
-rw-r--r--  pkg/sentry/platform/kvm/machine_amd64_unsafe.go | 64
-rw-r--r--  pkg/sentry/platform/kvm/machine_arm64.go | 183
-rw-r--r--  pkg/sentry/platform/kvm/machine_arm64_unsafe.go | 362
-rw-r--r--  pkg/sentry/platform/kvm/machine_unsafe.go | 66
-rw-r--r--  pkg/sentry/platform/kvm/physical_map.go | 12
-rw-r--r--  pkg/sentry/platform/kvm/physical_map_amd64.go | 22
-rw-r--r--  pkg/sentry/platform/kvm/physical_map_arm64.go | 19
-rw-r--r--  pkg/sentry/platform/kvm/testutil/testutil_arm64.s | 15
-rw-r--r--  pkg/sentry/platform/ptrace/BUILD | 1
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess.go | 17
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess_amd64.go | 48
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess_arm64.go | 39
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess_linux.go | 39
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go | 21
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess_unsafe.go | 2
-rw-r--r--  pkg/sentry/platform/ring0/BUILD | 42
-rw-r--r--  pkg/sentry/platform/ring0/aarch64.go | 109
-rw-r--r--  pkg/sentry/platform/ring0/defs.go | 11
-rw-r--r--  pkg/sentry/platform/ring0/defs_amd64.go | 11
-rw-r--r--  pkg/sentry/platform/ring0/defs_arm64.go | 136
-rw-r--r--  pkg/sentry/platform/ring0/entry_arm64.go | 60
-rw-r--r--  pkg/sentry/platform/ring0/entry_arm64.s | 591
-rw-r--r--  pkg/sentry/platform/ring0/gen_offsets/BUILD | 16
-rw-r--r--  pkg/sentry/platform/ring0/kernel_arm64.go | 58
-rw-r--r--  pkg/sentry/platform/ring0/lib_arm64.go | 39
-rw-r--r--  pkg/sentry/platform/ring0/lib_arm64.s | 118
-rw-r--r--  pkg/sentry/platform/ring0/offsets_arm64.go | 125
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/BUILD | 16
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables.go | 9
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go | 212
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go | 9
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go | 57
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go | 80
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/walker_arm64.go | 314
-rw-r--r--  pkg/sentry/sighandling/sighandling.go | 75
-rw-r--r--  pkg/sentry/sighandling/sighandling_unsafe.go | 26
-rw-r--r--  pkg/sentry/socket/control/BUILD | 1
-rw-r--r--  pkg/sentry/socket/control/control.go | 209
-rw-r--r--  pkg/sentry/socket/hostinet/BUILD | 3
-rw-r--r--  pkg/sentry/socket/hostinet/socket.go | 110
-rw-r--r--  pkg/sentry/socket/hostinet/stack.go | 116
-rw-r--r--  pkg/sentry/socket/netlink/BUILD | 1
-rw-r--r--  pkg/sentry/socket/netlink/provider.go | 7
-rw-r--r--  pkg/sentry/socket/netlink/route/protocol.go | 5
-rw-r--r--  pkg/sentry/socket/netlink/socket.go | 118
-rw-r--r--  pkg/sentry/socket/netlink/uevent/BUILD | 17
-rw-r--r--  pkg/sentry/socket/netlink/uevent/protocol.go | 60
-rw-r--r--  pkg/sentry/socket/netstack/netstack.go | 99
-rw-r--r--  pkg/sentry/socket/netstack/provider.go | 53
-rw-r--r--  pkg/sentry/socket/netstack/stack.go | 108
-rw-r--r--  pkg/sentry/socket/rpcinet/BUILD | 1
-rw-r--r--  pkg/sentry/socket/rpcinet/stack.go | 10
-rw-r--r--  pkg/sentry/socket/rpcinet/syscall_rpc.proto | 1
-rw-r--r--  pkg/sentry/socket/socket.go | 5
-rw-r--r--  pkg/sentry/socket/unix/unix.go | 3
-rw-r--r--  pkg/sentry/strace/BUILD | 1
-rw-r--r--  pkg/sentry/strace/linux64.go | 32
-rw-r--r--  pkg/sentry/strace/select.go | 56
-rw-r--r--  pkg/sentry/strace/socket.go | 9
-rw-r--r--  pkg/sentry/strace/strace.go | 2
-rw-r--r--  pkg/sentry/strace/strace.proto | 3
-rw-r--r--  pkg/sentry/strace/syscalls.go | 4
-rw-r--r--  pkg/sentry/syscalls/linux/BUILD | 2
-rw-r--r--  pkg/sentry/syscalls/linux/flags.go | 1
-rw-r--r--  pkg/sentry/syscalls/linux/linux64.go | 11
-rw-r--r--  pkg/sentry/syscalls/linux/linux64_amd64.go | 38
-rw-r--r--  pkg/sentry/syscalls/linux/linux64_arm64.go | 35
-rw-r--r--  pkg/sentry/syscalls/linux/sys_file.go | 91
-rw-r--r--  pkg/sentry/syscalls/linux/sys_futex.go | 10
-rw-r--r--  pkg/sentry/syscalls/linux/sys_poll.go | 71
-rw-r--r--  pkg/sentry/syscalls/linux/sys_socket.go | 41
-rw-r--r--  pkg/sentry/syscalls/linux/sys_thread.go | 81
-rw-r--r--  pkg/sentry/syscalls/linux/sys_write.go | 1
-rw-r--r--  pkg/sentry/syscalls/linux/sys_xattr.go | 169
-rw-r--r--  pkg/sentry/time/BUILD | 4
-rw-r--r--  pkg/sentry/usermem/bytes_io.go | 37
-rw-r--r--  pkg/sentry/vfs/BUILD | 5
-rw-r--r--  pkg/sentry/vfs/README.md | 4
-rw-r--r--  pkg/sentry/vfs/context.go | 13
-rw-r--r--  pkg/sentry/vfs/dentry.go | 162
-rw-r--r--  pkg/sentry/vfs/file_description.go | 310
-rw-r--r--  pkg/sentry/vfs/file_description_impl_util.go | 34
-rw-r--r--  pkg/sentry/vfs/file_description_impl_util_test.go | 14
-rw-r--r--  pkg/sentry/vfs/filesystem.go | 110
-rw-r--r--  pkg/sentry/vfs/filesystem_impl_util.go | 69
-rw-r--r--  pkg/sentry/vfs/filesystem_type.go | 10
-rw-r--r--  pkg/sentry/vfs/mount.go | 422
-rw-r--r--  pkg/sentry/vfs/mount_test.go | 34
-rw-r--r--  pkg/sentry/vfs/mount_unsafe.go | 66
-rw-r--r--  pkg/sentry/vfs/options.go | 26
-rw-r--r--  pkg/sentry/vfs/pathname.go | 153
-rw-r--r--  pkg/sentry/vfs/permissions.go | 62
-rw-r--r--  pkg/sentry/vfs/resolving_path.go | 56
-rw-r--r--  pkg/sentry/vfs/syscalls.go | 217
-rw-r--r--  pkg/sentry/vfs/testutil.go | 41
-rw-r--r--  pkg/sentry/vfs/vfs.go | 484
-rw-r--r--  pkg/sentry/watchdog/watchdog.go | 156
-rw-r--r--  pkg/sleep/sleep_unsafe.go | 2
-rw-r--r--  pkg/state/decode.go | 4
-rw-r--r--  pkg/state/encode.go | 4
-rw-r--r--  pkg/state/map.go | 11
-rw-r--r--  pkg/state/object.proto | 56
-rw-r--r--  pkg/state/state.go | 7
-rw-r--r--  pkg/state/state_test.go | 11
-rw-r--r--  pkg/syncutil/BUILD | 52
-rw-r--r--  pkg/syncutil/LICENSE | 27
-rw-r--r--  pkg/syncutil/README.md | 5
-rw-r--r--  pkg/syncutil/atomicptr_unsafe.go | 47
-rw-r--r--  pkg/syncutil/atomicptrtest/BUILD | 29
-rw-r--r--  pkg/syncutil/atomicptrtest/atomicptr_test.go | 31
-rw-r--r--  pkg/syncutil/downgradable_rwmutex_test.go | 150
-rw-r--r--  pkg/syncutil/downgradable_rwmutex_unsafe.go | 146
-rw-r--r--  pkg/syncutil/memmove_unsafe.go | 28
-rw-r--r--  pkg/syncutil/norace_unsafe.go | 35
-rw-r--r--  pkg/syncutil/race_unsafe.go | 41
-rw-r--r--  pkg/syncutil/seqatomic_unsafe.go | 72
-rw-r--r--  pkg/syncutil/seqatomictest/BUILD | 35
-rw-r--r--  pkg/syncutil/seqatomictest/seqatomic_test.go | 132
-rw-r--r--  pkg/syncutil/seqcount.go | 149
-rw-r--r--  pkg/syncutil/seqcount_test.go | 153
-rw-r--r--  pkg/syncutil/syncutil.go | 7
-rw-r--r--  pkg/tcpip/BUILD | 2
-rw-r--r--  pkg/tcpip/adapters/gonet/gonet_test.go | 12
-rw-r--r--  pkg/tcpip/buffer/prependable.go | 11
-rw-r--r--  pkg/tcpip/checker/BUILD | 1
-rw-r--r--  pkg/tcpip/checker/checker.go | 7
-rw-r--r--  pkg/tcpip/hash/jenkins/BUILD | 4
-rw-r--r--  pkg/tcpip/header/BUILD | 20
-rw-r--r--  pkg/tcpip/header/checksum.go | 50
-rw-r--r--  pkg/tcpip/header/checksum_test.go | 109
-rw-r--r--  pkg/tcpip/header/eth.go | 62
-rw-r--r--  pkg/tcpip/header/eth_test.go | 68
-rw-r--r--  pkg/tcpip/header/icmpv6.go | 11
-rw-r--r--  pkg/tcpip/header/ipv6.go | 67
-rw-r--r--  pkg/tcpip/header/ipv6_test.go | 45
-rw-r--r--  pkg/tcpip/header/ndp_neighbor_advert.go | 2
-rw-r--r--  pkg/tcpip/header/ndp_neighbor_solicit.go | 2
-rw-r--r--  pkg/tcpip/header/ndp_options.go | 441
-rw-r--r--  pkg/tcpip/header/ndp_router_advert.go | 112
-rw-r--r--  pkg/tcpip/header/ndp_test.go | 647
-rw-r--r--  pkg/tcpip/link/channel/BUILD | 2
-rw-r--r--  pkg/tcpip/link/channel/channel.go | 72
-rw-r--r--  pkg/tcpip/link/fdbased/BUILD | 4
-rw-r--r--  pkg/tcpip/link/fdbased/endpoint.go | 171
-rw-r--r--  pkg/tcpip/link/fdbased/endpoint_test.go | 32
-rw-r--r--  pkg/tcpip/link/fdbased/mmap.go | 8
-rw-r--r--  pkg/tcpip/link/fdbased/packet_dispatchers.go | 36
-rw-r--r--  pkg/tcpip/link/loopback/BUILD | 3
-rw-r--r--  pkg/tcpip/link/loopback/loopback.go | 47
-rw-r--r--  pkg/tcpip/link/muxed/BUILD | 4
-rw-r--r--  pkg/tcpip/link/muxed/injectable.go | 34
-rw-r--r--  pkg/tcpip/link/muxed/injectable_test.go | 14
-rw-r--r--  pkg/tcpip/link/rawfile/BUILD | 7
-rw-r--r--  pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go | 2
-rw-r--r--  pkg/tcpip/link/rawfile/rawfile_unsafe.go | 11
-rw-r--r--  pkg/tcpip/link/sharedmem/BUILD | 4
-rw-r--r--  pkg/tcpip/link/sharedmem/pipe/BUILD | 2
-rw-r--r--  pkg/tcpip/link/sharedmem/queue/BUILD | 2
-rw-r--r--  pkg/tcpip/link/sharedmem/sharedmem.go | 36
-rw-r--r--  pkg/tcpip/link/sharedmem/sharedmem_test.go | 70
-rw-r--r--  pkg/tcpip/link/sniffer/BUILD | 4
-rw-r--r--  pkg/tcpip/link/sniffer/sniffer.go | 118
-rw-r--r--  pkg/tcpip/link/tun/BUILD | 4
-rw-r--r--  pkg/tcpip/link/waitable/BUILD | 4
-rw-r--r--  pkg/tcpip/link/waitable/waitable.go | 32
-rw-r--r--  pkg/tcpip/link/waitable/waitable_test.go | 27
-rw-r--r--  pkg/tcpip/network/arp/BUILD | 4
-rw-r--r--  pkg/tcpip/network/arp/arp.go | 47
-rw-r--r--  pkg/tcpip/network/arp/arp_test.go | 14
-rw-r--r--  pkg/tcpip/network/fragmentation/BUILD | 10
-rw-r--r--  pkg/tcpip/network/ip_test.go | 66
-rw-r--r--  pkg/tcpip/network/ipv4/BUILD | 4
-rw-r--r--  pkg/tcpip/network/ipv4/icmp.go | 40
-rw-r--r--  pkg/tcpip/network/ipv4/ipv4.go | 179
-rw-r--r--  pkg/tcpip/network/ipv4/ipv4_test.go | 53
-rw-r--r--  pkg/tcpip/network/ipv6/BUILD | 4
-rw-r--r--  pkg/tcpip/network/ipv6/icmp.go | 149
-rw-r--r--  pkg/tcpip/network/ipv6/icmp_test.go | 564
-rw-r--r--  pkg/tcpip/network/ipv6/ipv6.go | 77
-rw-r--r--  pkg/tcpip/network/ipv6/ipv6_test.go | 8
-rw-r--r--  pkg/tcpip/network/ipv6/ndp_test.go | 199
-rw-r--r--  pkg/tcpip/packet_buffer.go | 67
-rw-r--r--  pkg/tcpip/packet_buffer_state.go | 27
-rw-r--r--  pkg/tcpip/ports/BUILD | 4
-rw-r--r--  pkg/tcpip/ports/ports.go | 148
-rw-r--r--  pkg/tcpip/ports/ports_test.go | 182
-rw-r--r--  pkg/tcpip/sample/tun_tcp_connect/BUILD | 1
-rw-r--r--  pkg/tcpip/sample/tun_tcp_echo/BUILD | 1
-rw-r--r--  pkg/tcpip/seqnum/BUILD | 4
-rw-r--r--  pkg/tcpip/stack/BUILD | 13
-rw-r--r--  pkg/tcpip/stack/ndp.go | 960
-rw-r--r--  pkg/tcpip/stack/ndp_test.go | 1781
-rw-r--r--  pkg/tcpip/stack/nic.go | 355
-rw-r--r--  pkg/tcpip/stack/registration.go | 177
-rw-r--r--  pkg/tcpip/stack/route.go | 39
-rw-r--r--  pkg/tcpip/stack/stack.go | 387
-rw-r--r--  pkg/tcpip/stack/stack_test.go | 454
-rw-r--r--  pkg/tcpip/stack/transport_demuxer.go | 224
-rw-r--r--  pkg/tcpip/stack/transport_demuxer_test.go | 12
-rw-r--r--  pkg/tcpip/stack/transport_test.go | 62
-rw-r--r--  pkg/tcpip/tcpip.go | 72
-rw-r--r--  pkg/tcpip/time_unsafe.go | 2
-rw-r--r--  pkg/tcpip/transport/icmp/BUILD | 8
-rw-r--r--  pkg/tcpip/transport/icmp/endpoint.go | 73
-rw-r--r--  pkg/tcpip/transport/icmp/protocol.go | 2
-rw-r--r--  pkg/tcpip/transport/packet/BUILD | 38
-rw-r--r--  pkg/tcpip/transport/packet/endpoint.go | 362
-rw-r--r--  pkg/tcpip/transport/packet/endpoint_state.go | 72
-rw-r--r--  pkg/tcpip/transport/raw/BUILD | 21
-rw-r--r--  pkg/tcpip/transport/raw/endpoint.go | 47
-rw-r--r--  pkg/tcpip/transport/raw/endpoint_state.go | 14
-rw-r--r--  pkg/tcpip/transport/raw/protocol.go | 12
-rw-r--r--  pkg/tcpip/transport/tcp/BUILD | 13
-rw-r--r--  pkg/tcpip/transport/tcp/accept.go | 69
-rw-r--r--  pkg/tcpip/transport/tcp/connect.go | 556
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go | 308
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint_state.go | 33
-rw-r--r--  pkg/tcpip/transport/tcp/forwarder.go | 5
-rw-r--r--  pkg/tcpip/transport/tcp/protocol.go | 86
-rw-r--r--  pkg/tcpip/transport/tcp/rcv.go | 190
-rw-r--r--  pkg/tcpip/transport/tcp/rcv_state.go | 29
-rw-r--r--  pkg/tcpip/transport/tcp/segment.go | 15
-rw-r--r--  pkg/tcpip/transport/tcp/snd.go | 48
-rw-r--r--  pkg/tcpip/transport/tcp/snd_state.go | 10
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_test.go | 1961
-rw-r--r--  pkg/tcpip/transport/tcp/testing/context/BUILD | 2
-rw-r--r--  pkg/tcpip/transport/tcp/testing/context/context.go | 77
-rw-r--r--  pkg/tcpip/transport/udp/BUILD | 9
-rw-r--r--  pkg/tcpip/transport/udp/endpoint.go | 146
-rw-r--r--  pkg/tcpip/transport/udp/endpoint_state.go | 2
-rw-r--r--  pkg/tcpip/transport/udp/forwarder.go | 9
-rw-r--r--  pkg/tcpip/transport/udp/protocol.go | 33
-rw-r--r--  pkg/tcpip/transport/udp/udp_test.go | 33
-rw-r--r--  pkg/waiter/BUILD | 8
381 files changed, 24499 insertions(+), 3265 deletions(-)
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 7c17109a6..9553f164d 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -23,6 +23,8 @@ go_library(
"exec.go",
"fcntl.go",
"file.go",
+ "file_amd64.go",
+ "file_arm64.go",
"fs.go",
"futex.go",
"inotify.go",
@@ -55,6 +57,7 @@ go_library(
"uio.go",
"utsname.go",
"wait.go",
+ "xattr.go",
],
importpath = "gvisor.dev/gvisor/pkg/abi/linux",
visibility = ["//visibility:public"],
diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go
index f78315ebf..6663a199c 100644
--- a/pkg/abi/linux/fcntl.go
+++ b/pkg/abi/linux/fcntl.go
@@ -16,15 +16,17 @@ package linux
// Commands from linux/fcntl.h.
const (
- F_DUPFD = 0x0
- F_GETFD = 0x1
- F_SETFD = 0x2
- F_GETFL = 0x3
- F_SETFL = 0x4
- F_SETLK = 0x6
- F_SETLKW = 0x7
- F_SETOWN = 0x8
- F_GETOWN = 0x9
+ F_DUPFD = 0
+ F_GETFD = 1
+ F_SETFD = 2
+ F_GETFL = 3
+ F_SETFL = 4
+ F_SETLK = 6
+ F_SETLKW = 7
+ F_SETOWN = 8
+ F_GETOWN = 9
+ F_SETOWN_EX = 15
+ F_GETOWN_EX = 16
F_DUPFD_CLOEXEC = 1024 + 6
F_SETPIPE_SZ = 1024 + 7
F_GETPIPE_SZ = 1024 + 8
@@ -32,9 +34,9 @@ const (
// Commands for F_SETLK.
const (
- F_RDLCK = 0x0
- F_WRLCK = 0x1
- F_UNLCK = 0x2
+ F_RDLCK = 0
+ F_WRLCK = 1
+ F_UNLCK = 2
)
// Flags for fcntl.
@@ -42,7 +44,7 @@ const (
FD_CLOEXEC = 00000001
)
-// Lock structure for F_SETLK.
+// Flock is the lock structure for F_SETLK.
type Flock struct {
Type int16
Whence int16
@@ -52,3 +54,16 @@ type Flock struct {
Pid int32
_ [4]byte
}
+
+// Flags for F_SETOWN_EX and F_GETOWN_EX.
+const (
+ F_OWNER_TID = 0
+ F_OWNER_PID = 1
+ F_OWNER_PGRP = 2
+)
+
+// FOwnerEx is the owner structure for F_SETOWN_EX and F_GETOWN_EX.
+type FOwnerEx struct {
+ Type int32
+ PID int32
+}
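
F_SETOWN_EX and F_GETOWN_EX let a process direct SIGIO/SIGURG delivery to a specific thread, process, or process group rather than just a PID. A minimal sketch of driving the new constants against a host fd follows; the setOwnerTID helper and its arguments are illustrative, not part of this change:

    package fcntlutil

    import (
            "fmt"
            "syscall"
            "unsafe"

            "gvisor.dev/gvisor/pkg/abi/linux"
    )

    // setOwnerTID asks the kernel to deliver SIGIO for fd to the single
    // thread tid, something plain F_SETOWN cannot express. fd and tid are
    // assumed to be valid.
    func setOwnerTID(fd int, tid int32) error {
            owner := linux.FOwnerEx{Type: linux.F_OWNER_TID, PID: tid}
            if _, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(fd),
                    uintptr(linux.F_SETOWN_EX), uintptr(unsafe.Pointer(&owner))); errno != 0 {
                    return fmt.Errorf("fcntl(F_SETOWN_EX): %v", errno)
            }
            return nil
    }
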
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index 257f67222..16791d03e 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -144,9 +144,13 @@ const (
ModeCharacterDevice = S_IFCHR
ModeNamedPipe = S_IFIFO
- ModeSetUID = 04000
- ModeSetGID = 02000
- ModeSticky = 01000
+ S_ISUID = 04000
+ S_ISGID = 02000
+ S_ISVTX = 01000
+
+ ModeSetUID = S_ISUID
+ ModeSetGID = S_ISGID
+ ModeSticky = S_ISVTX
ModeUserAll = 0700
ModeUserRead = 0400
@@ -186,25 +190,6 @@ const (
RWF_VALID = RWF_HIPRI | RWF_DSYNC | RWF_SYNC
)
-// Stat represents struct stat.
-type Stat struct {
- Dev uint64
- Ino uint64
- Nlink uint64
- Mode uint32
- UID uint32
- GID uint32
- _ int32
- Rdev uint64
- Size int64
- Blksize int64
- Blocks int64
- ATime Timespec
- MTime Timespec
- CTime Timespec
- _ [3]int64
-}
-
// SizeOfStat is the size of a Stat struct.
var SizeOfStat = binary.Size(Stat{})
@@ -301,6 +286,29 @@ func (m FileMode) String() string {
return strings.Join(s, "|")
}
+// DirentType maps file types to dirent types appropriate for (struct
+// dirent)::d_type.
+func (m FileMode) DirentType() uint8 {
+ switch m.FileType() {
+ case ModeSocket:
+ return DT_SOCK
+ case ModeSymlink:
+ return DT_LNK
+ case ModeRegular:
+ return DT_REG
+ case ModeBlockDevice:
+ return DT_BLK
+ case ModeDirectory:
+ return DT_DIR
+ case ModeCharacterDevice:
+ return DT_CHR
+ case ModeNamedPipe:
+ return DT_FIFO
+ default:
+ return DT_UNKNOWN
+ }
+}
+
var modeExtraBits = abi.FlagSet{
{
Flag: ModeSetUID,
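
DirentType gives getdents emulation a direct mapping from an inode's mode bits to a (struct dirent)::d_type value. A quick illustration, with arbitrary example modes:

    package main

    import (
            "fmt"

            "gvisor.dev/gvisor/pkg/abi/linux"
    )

    func main() {
            // Only the file-type bits select the dirent type; permission
            // bits are ignored by FileMode.FileType.
            dir := linux.FileMode(linux.ModeDirectory | 0755)
            fifo := linux.FileMode(linux.ModeNamedPipe | 0600)
            fmt.Println(dir.DirentType() == linux.DT_DIR)   // true
            fmt.Println(fifo.DirentType() == linux.DT_FIFO) // true
    }
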
diff --git a/pkg/abi/linux/file_amd64.go b/pkg/abi/linux/file_amd64.go
new file mode 100644
index 000000000..74c554be6
--- /dev/null
+++ b/pkg/abi/linux/file_amd64.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Stat represents struct stat.
+type Stat struct {
+ Dev uint64
+ Ino uint64
+ Nlink uint64
+ Mode uint32
+ UID uint32
+ GID uint32
+ _ int32
+ Rdev uint64
+ Size int64
+ Blksize int64
+ Blocks int64
+ ATime Timespec
+ MTime Timespec
+ CTime Timespec
+ _ [3]int64
+}
diff --git a/pkg/abi/linux/file_arm64.go b/pkg/abi/linux/file_arm64.go
new file mode 100644
index 000000000..f16c07589
--- /dev/null
+++ b/pkg/abi/linux/file_arm64.go
@@ -0,0 +1,35 @@
+// Copyright 2019 The gVisor Authors.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Stat represents struct stat.
+type Stat struct {
+ Dev uint64
+ Ino uint64
+ Mode uint32
+ Nlink uint32
+ UID uint32
+ GID uint32
+ Rdev uint64
+ _ uint64
+ Size int64
+ Blksize int32
+ _ int32
+ Blocks int64
+ ATime Timespec
+ MTime Timespec
+ CTime Timespec
+ _ [2]int32
+}
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go
index b416e3472..2c652baa2 100644
--- a/pkg/abi/linux/fs.go
+++ b/pkg/abi/linux/fs.go
@@ -92,3 +92,10 @@ const (
SYNC_FILE_RANGE_WRITE = 2
SYNC_FILE_RANGE_WAIT_AFTER = 4
)
+
+// Flag argument to renameat2(2), from include/uapi/linux/fs.h.
+const (
+ RENAME_NOREPLACE = (1 << 0) // Don't overwrite target.
+ RENAME_EXCHANGE = (1 << 1) // Exchange src and dst.
+ RENAME_WHITEOUT = (1 << 2) // Whiteout src.
+)
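
These flag values feed renameat2(2). A sketch of exercising RENAME_NOREPLACE from Go via golang.org/x/sys/unix; the helper name is illustrative:

    package renameutil

    import (
            "fmt"

            "golang.org/x/sys/unix"

            "gvisor.dev/gvisor/pkg/abi/linux"
    )

    // renameNoReplace renames oldPath to newPath, failing with EEXIST
    // instead of silently replacing an existing newPath.
    func renameNoReplace(oldPath, newPath string) error {
            if err := unix.Renameat2(unix.AT_FDCWD, oldPath, unix.AT_FDCWD,
                    newPath, linux.RENAME_NOREPLACE); err != nil {
                    return fmt.Errorf("renameat2: %v", err)
            }
            return nil
    }
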
diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go
index 152f6b144..3898d2314 100644
--- a/pkg/abi/linux/netlink_route.go
+++ b/pkg/abi/linux/netlink_route.go
@@ -325,3 +325,9 @@ const (
RTA_SPORT = 28
RTA_DPORT = 29
)
+
+// Route flags, from include/uapi/linux/route.h.
+const (
+ RTF_GATEWAY = 0x2
+ RTF_UP = 0x1
+)
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index d5b731390..766ee4014 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -256,6 +256,17 @@ type SockAddrInet6 struct {
Scope_id uint32
}
+// SockAddrLink is a struct sockaddr_ll, from uapi/linux/if_packet.h.
+type SockAddrLink struct {
+ Family uint16
+ Protocol uint16
+ InterfaceIndex int32
+ ARPHardwareType uint16
+ PacketType byte
+ HardwareAddrLen byte
+ HardwareAddr [8]byte
+}
+
// UnixPathMax is the maximum length of the path in an AF_UNIX socket.
//
// From uapi/linux/un.h.
@@ -278,6 +289,7 @@ type SockAddr interface {
func (s *SockAddrInet) implementsSockAddr() {}
func (s *SockAddrInet6) implementsSockAddr() {}
+func (s *SockAddrLink) implementsSockAddr() {}
func (s *SockAddrUnix) implementsSockAddr() {}
func (s *SockAddrNetlink) implementsSockAddr() {}
@@ -410,6 +422,15 @@ type ControlMessageRights []int32
// ControlMessageRights.
const SizeOfControlMessageRight = 4
+// SizeOfControlMessageInq is the size of a TCP_INQ control message.
+const SizeOfControlMessageInq = 4
+
+// SizeOfControlMessageTOS is the size of an IP_TOS control message.
+const SizeOfControlMessageTOS = 1
+
+// SizeOfControlMessageTClass is the size of an IPV6_TCLASS control message.
+const SizeOfControlMessageTClass = 4
+
// SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call.
// From net/scm.h.
const SCM_MAX_FD = 253
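
SockAddrLink gives the sentry a typed view of struct sockaddr_ll for AF_PACKET sockets. A sketch of constructing one; field values are illustrative, and note that Protocol is carried in network byte order:

    package packetutil

    import "gvisor.dev/gvisor/pkg/abi/linux"

    // linkAddrFor builds a sockaddr_ll suitable for binding a packet
    // socket to a specific interface. proto is assumed to already be in
    // network byte order (e.g. htons(ETH_P_ALL)).
    func linkAddrFor(proto uint16, ifindex int32) linux.SockAddrLink {
            return linux.SockAddrLink{
                    Family:         linux.AF_PACKET,
                    Protocol:       proto,
                    InterfaceIndex: ifindex,
            }
    }
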
diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go
new file mode 100644
index 000000000..a3b6406fa
--- /dev/null
+++ b/pkg/abi/linux/xattr.go
@@ -0,0 +1,27 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Constants for extended attributes.
+const (
+ XATTR_NAME_MAX = 255
+ XATTR_SIZE_MAX = 65536
+
+ XATTR_CREATE = 1
+ XATTR_REPLACE = 2
+
+ XATTR_USER_PREFIX = "user."
+ XATTR_USER_PREFIX_LEN = len(XATTR_USER_PREFIX)
+)
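
A sketch of the validation a setxattr path can perform with the new constants — enforcing the name-length cap and accepting only user.* attributes; the helper is illustrative:

    package xattrutil

    import (
            "fmt"
            "strings"

            "gvisor.dev/gvisor/pkg/abi/linux"
    )

    // checkXattrName rejects names that exceed the kernel's length limit
    // or fall outside the user namespace.
    func checkXattrName(name string) error {
            if len(name) > linux.XATTR_NAME_MAX {
                    return fmt.Errorf("xattr name exceeds %d bytes", linux.XATTR_NAME_MAX)
            }
            if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
                    return fmt.Errorf("only %q attributes are supported", linux.XATTR_USER_PREFIX)
            }
            return nil
    }
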
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index 5d61dc2ff..d37047368 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -183,6 +183,33 @@ const (
X86FeatureAVX512VBMI
X86FeatureUMIP
X86FeaturePKU
+ X86FeatureOSPKE
+ X86FeatureWAITPKG
+ X86FeatureAVX512_VBMI2
+ _ // ecx bit 7 is reserved
+ X86FeatureGFNI
+ X86FeatureVAES
+ X86FeatureVPCLMULQDQ
+ X86FeatureAVX512_VNNI
+ X86FeatureAVX512_BITALG
+ X86FeatureTME
+ X86FeatureAVX512_VPOPCNTDQ
+ _ // ecx bit 15 is reserved
+ X86FeatureLA57
+ // ecx bits 17-21 are reserved
+ _
+ _
+ _
+ _
+ _
+ X86FeatureRDPID
+ // ecx bits 23-24 are reserved
+ _
+ _
+ X86FeatureCLDEMOTE
+ _ // ecx bit 26 is reserved
+ X86FeatureMOVDIRI
+ X86FeatureMOVDIR64B
)
// Block 4 constants are for xsave capabilities in CPUID.(EAX=0DH,ECX=01H):EAX.
@@ -353,9 +380,24 @@ var x86FeatureStrings = map[Feature]string{
X86FeatureAVX512VL: "avx512vl",
// Block 3.
- X86FeatureAVX512VBMI: "avx512vbmi",
- X86FeatureUMIP: "umip",
- X86FeaturePKU: "pku",
+ X86FeatureAVX512VBMI: "avx512vbmi",
+ X86FeatureUMIP: "umip",
+ X86FeaturePKU: "pku",
+ X86FeatureOSPKE: "ospke",
+ X86FeatureWAITPKG: "waitpkg",
+ X86FeatureAVX512_VBMI2: "avx512_vbmi2",
+ X86FeatureGFNI: "gfni",
+ X86FeatureVAES: "vaes",
+ X86FeatureVPCLMULQDQ: "vpclmulqdq",
+ X86FeatureAVX512_VNNI: "avx512_vnni",
+ X86FeatureAVX512_BITALG: "avx512_bitalg",
+ X86FeatureTME: "tme",
+ X86FeatureAVX512_VPOPCNTDQ: "avx512_vpopcntdq",
+ X86FeatureLA57: "la57",
+ X86FeatureRDPID: "rdpid",
+ X86FeatureCLDEMOTE: "cldemote",
+ X86FeatureMOVDIRI: "movdiri",
+ X86FeatureMOVDIR64B: "movdir64b",
// Block 4.
X86FeatureXSAVEOPT: "xsaveopt",
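
The new block-3 bits flow through the existing FeatureSet plumbing, so checking one looks the same as for any older feature. A sketch, assuming the package's existing HostFeatureSet and HasFeature helpers:

    package main

    import (
            "fmt"

            "gvisor.dev/gvisor/pkg/cpuid"
    )

    func main() {
            // Query two of the newly-added block-3 feature bits on the
            // host CPU.
            fs := cpuid.HostFeatureSet()
            fmt.Println("rdpid:", fs.HasFeature(cpuid.X86FeatureRDPID))
            fmt.Println("la57:", fs.HasFeature(cpuid.X86FeatureLA57))
    }
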
diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD
index 71f2abc83..0b4b7cc44 100644
--- a/pkg/eventchannel/BUILD
+++ b/pkg/eventchannel/BUILD
@@ -25,6 +25,7 @@ go_library(
proto_library(
name = "eventchannel_proto",
srcs = ["event.proto"],
+ visibility = ["//:sandbox"],
)
go_proto_library(
diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD
index c7f549428..afa8f7659 100644
--- a/pkg/fd/BUILD
+++ b/pkg/fd/BUILD
@@ -8,9 +8,6 @@ go_library(
srcs = ["fd.go"],
importpath = "gvisor.dev/gvisor/pkg/fd",
visibility = ["//visibility:public"],
- deps = [
- "//pkg/unet",
- ],
)
go_test(
diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go
index 7691b477b..83bcfe220 100644
--- a/pkg/fd/fd.go
+++ b/pkg/fd/fd.go
@@ -22,8 +22,6 @@ import (
"runtime"
"sync/atomic"
"syscall"
-
- "gvisor.dev/gvisor/pkg/unet"
)
// ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It
@@ -187,12 +185,6 @@ func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) {
return New(f), nil
}
-// DialUnix connects to a Unix Domain Socket and return the file descriptor.
-func DialUnix(path string) (*FD, error) {
- socket, err := unet.Connect(path, false)
- return New(socket.FD()), err
-}
-
// Close closes the file descriptor contained in the FD.
//
// Close is safe to call multiple times, but will return an error after the
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index 5643d5f26..e590a71ba 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -19,7 +19,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/log",
"//pkg/memutil",
- "//third_party/gvsync",
+ "//pkg/syncutil",
],
)
diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go
index 8390915a2..e7c3a3a0b 100644
--- a/pkg/flipcall/ctrl_futex.go
+++ b/pkg/flipcall/ctrl_futex.go
@@ -113,7 +113,7 @@ func (ep *Endpoint) enterFutexWait() error {
return nil
case epsBlocked | epsShutdown:
atomic.AddInt32(&ep.ctrl.state, -epsBlocked)
- return shutdownError{}
+ return ShutdownError{}
default:
// Most likely due to ep.enterFutexWait() being called concurrently
// from multiple goroutines.
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index 386cee42c..3cdb576e1 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -136,8 +136,8 @@ func (ep *Endpoint) unmapPacket() {
// Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(),
// ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer
-// Endpoint, to unblock and return errors. It does not wait for concurrent
-// calls to return. Successive calls to Shutdown have no effect.
+// Endpoint, to unblock and return ShutdownErrors. It does not wait for
+// concurrent calls to return. Successive calls to Shutdown have no effect.
//
// Shutdown is the only Endpoint method that may be called concurrently with
// other methods on the same Endpoint.
@@ -154,10 +154,12 @@ func (ep *Endpoint) isShutdownLocally() bool {
return atomic.LoadUint32(&ep.shutdown) != 0
}
-type shutdownError struct{}
+// ShutdownError is returned by most Endpoint methods after Endpoint.Shutdown()
+// has been called.
+type ShutdownError struct{}
// Error implements error.Error.
-func (shutdownError) Error() string {
+func (ShutdownError) Error() string {
return "flipcall connection shutdown"
}
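
Exporting ShutdownError lets callers separate expected teardown from real failures — the p9/server.go change further down does exactly this. A sketch of the caller side; recvReply is illustrative:

    package flipcallutil

    import "gvisor.dev/gvisor/pkg/flipcall"

    // recvReply treats a ShutdownError from SendRecv as a clean
    // end-of-connection rather than a reportable failure. ep is assumed to
    // be a connected Endpoint and dataLen a valid packet length.
    func recvReply(ep *flipcall.Endpoint, dataLen uint32) (uint32, bool, error) {
            n, err := ep.SendRecv(dataLen)
            if err != nil {
                    if _, ok := err.(flipcall.ShutdownError); ok {
                            return 0, false, nil // peer called Shutdown; expected.
                    }
                    return 0, false, err
            }
            return n, true, nil
    }
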
diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go
index a37952637..27b8939fc 100644
--- a/pkg/flipcall/flipcall_unsafe.go
+++ b/pkg/flipcall/flipcall_unsafe.go
@@ -18,7 +18,7 @@ import (
"reflect"
"unsafe"
- "gvisor.dev/gvisor/third_party/gvsync"
+ "gvisor.dev/gvisor/pkg/syncutil"
)
// Packets consist of a 16-byte header followed by an arbitrarily-sized
@@ -75,13 +75,13 @@ func (ep *Endpoint) Data() []byte {
var ioSync int64
func raceBecomeActive() {
- if gvsync.RaceEnabled {
- gvsync.RaceAcquire((unsafe.Pointer)(&ioSync))
+ if syncutil.RaceEnabled {
+ syncutil.RaceAcquire((unsafe.Pointer)(&ioSync))
}
}
func raceBecomeInactive() {
- if gvsync.RaceEnabled {
- gvsync.RaceReleaseMerge((unsafe.Pointer)(&ioSync))
+ if syncutil.RaceEnabled {
+ syncutil.RaceReleaseMerge((unsafe.Pointer)(&ioSync))
}
}
diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go
index b127a2bbb..168c1ccff 100644
--- a/pkg/flipcall/futex_linux.go
+++ b/pkg/flipcall/futex_linux.go
@@ -61,7 +61,7 @@ func (ep *Endpoint) futexSwitchToPeer() error {
if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) {
switch cs := atomic.LoadUint32(ep.connState()); cs {
case csShutdown:
- return shutdownError{}
+ return ShutdownError{}
default:
return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs)
}
@@ -81,14 +81,14 @@ func (ep *Endpoint) futexSwitchFromPeer() error {
return nil
case ep.inactiveState:
if ep.isShutdownLocally() {
- return shutdownError{}
+ return ShutdownError{}
}
if err := ep.futexWaitConnState(ep.inactiveState); err != nil {
return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
}
continue
case csShutdown:
- return shutdownError{}
+ return ShutdownError{}
default:
return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs)
}
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
index a6cc0617e..de9357389 100644
--- a/pkg/p9/client_file.go
+++ b/pkg/p9/client_file.go
@@ -17,7 +17,6 @@ package p9
import (
"fmt"
"io"
- "runtime"
"sync/atomic"
"syscall"
@@ -45,15 +44,10 @@ func (c *Client) Attach(name string) (File, error) {
// newFile returns a new client file.
func (c *Client) newFile(fid FID) *clientFile {
- cf := &clientFile{
+ return &clientFile{
client: c,
fid: fid,
}
-
- // Make sure the file is closed.
- runtime.SetFinalizer(cf, (*clientFile).Close)
-
- return cf
}
// clientFile is provided to clients.
@@ -192,7 +186,6 @@ func (c *clientFile) Remove() error {
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
return syscall.EBADF
}
- runtime.SetFinalizer(c, nil)
// Send the remove message.
if err := c.client.sendRecv(&Tremove{FID: c.fid}, &Rremove{}); err != nil {
@@ -214,7 +207,6 @@ func (c *clientFile) Close() error {
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
return syscall.EBADF
}
- runtime.SetFinalizer(c, nil)
// Send the close message.
if err := c.client.sendRecv(&Tclunk{FID: c.fid}, &Rclunk{}); err != nil {
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
index 907445e15..96d1f2a8e 100644
--- a/pkg/p9/file.go
+++ b/pkg/p9/file.go
@@ -116,7 +116,7 @@ type File interface {
// N.B. The server must resolve any lazy paths when open is called.
// After this point, read and write may be called on files with no
// deletion check, so resolving in the data path is not viable.
- Open(mode OpenFlags) (*fd.FD, QID, uint32, error)
+ Open(flags OpenFlags) (*fd.FD, QID, uint32, error)
// Read reads from this file. Open must be called first.
//
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index ba9a55d6d..b9582c07f 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -257,7 +257,6 @@ func CanOpen(mode FileMode) bool {
// handle implements handler.handle.
func (t *Tlopen) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -272,15 +271,15 @@ func (t *Tlopen) handle(cs *connState) message {
return newErr(syscall.EINVAL)
}
- // Are flags valid?
- flags := t.Flags &^ OpenFlagsIgnoreMask
- if flags&^OpenFlagsModeMask != 0 {
- return newErr(syscall.EINVAL)
- }
-
- // Is this an attempt to open a directory as writable? Don't accept.
- if ref.mode.IsDir() && flags != ReadOnly {
- return newErr(syscall.EINVAL)
+ if ref.mode.IsDir() {
+ // Directory must be opened ReadOnly.
+ if t.Flags&OpenFlagsModeMask != ReadOnly {
+ return newErr(syscall.EISDIR)
+ }
+ // Directory not truncatable.
+ if t.Flags&OpenTruncate != 0 {
+ return newErr(syscall.EISDIR)
+ }
}
var (
@@ -294,7 +293,6 @@ func (t *Tlopen) handle(cs *connState) message {
return syscall.EINVAL
}
- // Do the open.
osFile, qid, ioUnit, err = ref.file.Open(t.Flags)
return err
}); err != nil {
@@ -311,12 +309,10 @@ func (t *Tlopen) handle(cs *connState) message {
}
func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return nil, err
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return nil, syscall.EBADF
@@ -390,12 +386,10 @@ func (t *Tsymlink) handle(cs *connState) message {
}
func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return nil, err
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.Directory)
if !ok {
return nil, syscall.EBADF
@@ -426,19 +420,16 @@ func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) {
// handle implements handler.handle.
func (t *Tlink) handle(cs *connState) message {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return newErr(err)
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.Directory)
if !ok {
return newErr(syscall.EBADF)
}
defer ref.DecRef()
- // Lookup the other FID.
refTarget, ok := cs.LookupFID(t.Target)
if !ok {
return newErr(syscall.EBADF)
@@ -467,7 +458,6 @@ func (t *Tlink) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Trenameat) handle(cs *connState) message {
- // Don't allow complex names.
if err := checkSafeName(t.OldName); err != nil {
return newErr(err)
}
@@ -475,14 +465,12 @@ func (t *Trenameat) handle(cs *connState) message {
return newErr(err)
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.OldDirectory)
if !ok {
return newErr(syscall.EBADF)
}
defer ref.DecRef()
- // Lookup the other FID.
refTarget, ok := cs.LookupFID(t.NewDirectory)
if !ok {
return newErr(syscall.EBADF)
@@ -523,12 +511,10 @@ func (t *Trenameat) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tunlinkat) handle(cs *connState) message {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return newErr(err)
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.Directory)
if !ok {
return newErr(syscall.EBADF)
@@ -577,19 +563,16 @@ func (t *Tunlinkat) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Trename) handle(cs *connState) message {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return newErr(err)
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
}
defer ref.DecRef()
- // Lookup the target.
refTarget, ok := cs.LookupFID(t.Directory)
if !ok {
return newErr(syscall.EBADF)
@@ -641,7 +624,6 @@ func (t *Trename) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Treadlink) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -669,7 +651,6 @@ func (t *Treadlink) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tread) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -708,7 +689,6 @@ func (t *Tread) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Twrite) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -747,12 +727,10 @@ func (t *Tmknod) handle(cs *connState) message {
}
func (t *Tmknod) do(cs *connState, uid UID) (*Rmknod, error) {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return nil, err
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.Directory)
if !ok {
return nil, syscall.EBADF
@@ -791,12 +769,10 @@ func (t *Tmkdir) handle(cs *connState) message {
}
func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return nil, err
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.Directory)
if !ok {
return nil, syscall.EBADF
@@ -827,7 +803,6 @@ func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) {
// handle implements handler.handle.
func (t *Tgetattr) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -856,7 +831,6 @@ func (t *Tgetattr) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tsetattr) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -883,7 +857,6 @@ func (t *Tsetattr) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tallocate) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -917,7 +890,6 @@ func (t *Tallocate) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Txattrwalk) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -930,7 +902,6 @@ func (t *Txattrwalk) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Txattrcreate) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -943,7 +914,6 @@ func (t *Txattrcreate) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Treaddir) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.Directory)
if !ok {
return newErr(syscall.EBADF)
@@ -977,7 +947,6 @@ func (t *Treaddir) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tfsync) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -1001,7 +970,6 @@ func (t *Tfsync) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tstatfs) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -1192,7 +1160,6 @@ func doWalk(cs *connState, ref *fidRef, names []string, getattr bool) (qids []QI
// handle implements handler.handle.
func (t *Twalk) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -1213,7 +1180,6 @@ func (t *Twalk) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Twalkgetattr) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -1270,7 +1236,6 @@ func (t *Tumknod) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tlconnect) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -1303,7 +1268,6 @@ func (t *Tchannel) handle(cs *connState) message {
return newErr(err)
}
- // Lookup the given channel.
ch := cs.lookupChannel(t.ID)
if ch == nil {
return newErr(syscall.ENOSYS)
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index 25530adca..d3090535a 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -32,21 +32,21 @@ import (
type OpenFlags uint32
const (
- // ReadOnly is a Topen and Tcreate flag indicating read-only mode.
+ // ReadOnly is a Tlopen and Tlcreate flag indicating read-only mode.
ReadOnly OpenFlags = 0
- // WriteOnly is a Topen and Tcreate flag indicating write-only mode.
+ // WriteOnly is a Tlopen and Tlcreate flag indicating write-only mode.
WriteOnly OpenFlags = 1
- // ReadWrite is a Topen flag indicates read-write mode.
+ // ReadWrite is a Tlopen flag indicating read-write mode.
ReadWrite OpenFlags = 2
// OpenFlagsModeMask is a mask of valid OpenFlags mode bits.
OpenFlagsModeMask OpenFlags = 3
- // OpenFlagsIgnoreMask is a list of OpenFlags mode bits that are ignored for Tlopen.
- // Note that syscall.O_LARGEFILE is set to zero, use value from Linux fcntl.h.
- OpenFlagsIgnoreMask OpenFlags = syscall.O_DIRECTORY | syscall.O_NOATIME | 0100000
+ // OpenTruncate is a Tlopen flag indicating that the opened file should be
+ // truncated.
+ OpenTruncate OpenFlags = 01000
)
// ConnectFlags is the mode passed to Connect operations.
@@ -71,25 +71,32 @@ const (
// OSFlags converts a p9.OpenFlags to an int compatible with open(2).
func (o OpenFlags) OSFlags() int {
- return int(o & OpenFlagsModeMask)
+ // "flags contains Linux open(2) flags bits" - 9P2000.L
+ return int(o)
}
// String implements fmt.Stringer.
func (o OpenFlags) String() string {
- switch o {
+ var buf strings.Builder
+ switch mode := o & OpenFlagsModeMask; mode {
case ReadOnly:
- return "ReadOnly"
+ buf.WriteString("ReadOnly")
case WriteOnly:
- return "WriteOnly"
+ buf.WriteString("WriteOnly")
case ReadWrite:
- return "ReadWrite"
- case OpenFlagsModeMask:
- return "OpenFlagsModeMask"
- case OpenFlagsIgnoreMask:
- return "OpenFlagsIgnoreMask"
+ buf.WriteString("ReadWrite")
default:
- return "UNDEFINED"
+ fmt.Fprintf(&buf, "%#o", mode)
}
+ otherFlags := o &^ OpenFlagsModeMask
+ if otherFlags&OpenTruncate != 0 {
+ buf.WriteString("|OpenTruncate")
+ otherFlags &^= OpenTruncate
+ }
+ if otherFlags != 0 {
+ fmt.Fprintf(&buf, "|%#o", otherFlags)
+ }
+ return buf.String()
}
// Tag is a message tag.
@@ -814,7 +821,7 @@ func StatToAttr(s *syscall.Stat_t, req AttrMask) (Attr, AttrMask) {
attr.Mode = FileMode(s.Mode)
}
if req.NLink {
- attr.NLink = s.Nlink
+ attr.NLink = uint64(s.Nlink)
}
if req.UID {
attr.UID = UID(s.Uid)
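
Since OSFlags now passes the full Linux open(2) bit pattern through, a server backend can honor OpenTruncate without special-casing it — the bit simply reaches the host open. A sketch; openHost is illustrative:

    package p9util

    import (
            "syscall"

            "gvisor.dev/gvisor/pkg/p9"
    )

    // openHost opens path on the host with the client-supplied 9P flags.
    // If the client set OpenTruncate, O_TRUNC reaches the host open(2).
    func openHost(path string, flags p9.OpenFlags) (int, error) {
            return syscall.Open(path, flags.OSFlags()|syscall.O_CLOEXEC, 0)
    }
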
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index 8bbdb2488..6e758148d 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -1044,11 +1044,11 @@ func TestReaddir(t *testing.T) {
if _, err := f.Readdir(0, 1); err != syscall.EINVAL {
t.Errorf("readdir got %v, wanted EINVAL", err)
}
- if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EINVAL {
- t.Errorf("readdir got %v, wanted EINVAL", err)
+ if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EISDIR {
+ t.Errorf("readdir got %v, wanted EISDIR", err)
}
- if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EINVAL {
- t.Errorf("readdir got %v, wanted EINVAL", err)
+ if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EISDIR {
+ t.Errorf("readdir got %v, wanted EISDIR", err)
}
backend.EXPECT().Open(p9.ReadOnly).Times(1)
if _, _, _, err := f.Open(p9.ReadOnly); err != nil {
@@ -1065,75 +1065,93 @@ func TestReaddir(t *testing.T) {
func TestOpen(t *testing.T) {
type openTest struct {
name string
- mode p9.OpenFlags
+ flags p9.OpenFlags
err error
match func(p9.FileMode) bool
}
cases := []openTest{
{
- name: "invalid",
- mode: ^p9.OpenFlagsModeMask,
- err: syscall.EINVAL,
- match: func(p9.FileMode) bool { return true },
- },
- {
name: "not-openable-read-only",
- mode: p9.ReadOnly,
+ flags: p9.ReadOnly,
err: syscall.EINVAL,
match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) },
},
{
name: "not-openable-write-only",
- mode: p9.WriteOnly,
+ flags: p9.WriteOnly,
err: syscall.EINVAL,
match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) },
},
{
name: "not-openable-read-write",
- mode: p9.ReadWrite,
+ flags: p9.ReadWrite,
err: syscall.EINVAL,
match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) },
},
{
name: "directory-read-only",
- mode: p9.ReadOnly,
+ flags: p9.ReadOnly,
err: nil,
match: func(mode p9.FileMode) bool { return mode.IsDir() },
},
{
name: "directory-read-write",
- mode: p9.ReadWrite,
- err: syscall.EINVAL,
+ flags: p9.ReadWrite,
+ err: syscall.EISDIR,
match: func(mode p9.FileMode) bool { return mode.IsDir() },
},
{
name: "directory-write-only",
- mode: p9.WriteOnly,
- err: syscall.EINVAL,
+ flags: p9.WriteOnly,
+ err: syscall.EISDIR,
match: func(mode p9.FileMode) bool { return mode.IsDir() },
},
{
name: "read-only",
- mode: p9.ReadOnly,
+ flags: p9.ReadOnly,
err: nil,
match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) },
},
{
name: "write-only",
- mode: p9.WriteOnly,
+ flags: p9.WriteOnly,
err: nil,
match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
},
{
name: "read-write",
- mode: p9.ReadWrite,
+ flags: p9.ReadWrite,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "directory-read-only-truncate",
+ flags: p9.ReadOnly | p9.OpenTruncate,
+ err: syscall.EISDIR,
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ },
+ {
+ name: "read-only-truncate",
+ flags: p9.ReadOnly | p9.OpenTruncate,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "write-only-truncate",
+ flags: p9.WriteOnly | p9.OpenTruncate,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "read-write-truncate",
+ flags: p9.ReadWrite | p9.OpenTruncate,
err: nil,
match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
},
}
- // Open(mode OpenFlags) (*fd.FD, QID, uint32, error)
+ // Open(flags OpenFlags) (*fd.FD, QID, uint32, error)
// - only works on Regular, NamedPipe, BlockDevice, CharacterDevice
// - returning a file works as expected
for name := range newTypeMap(nil) {
@@ -1171,25 +1189,25 @@ func TestOpen(t *testing.T) {
// Attempt the given open.
if tc.err != nil {
// We expect an error, just test and return.
- if _, _, _, err := f.Open(tc.mode); err != tc.err {
- t.Fatalf("open with mode %v got %v, want %v", tc.mode, err, tc.err)
+ if _, _, _, err := f.Open(tc.flags); err != tc.err {
+ t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err)
}
return
}
// Run an FD test, since we expect success.
fdTest(t, func(send *fd.FD) *fd.FD {
- backend.EXPECT().Open(tc.mode).Return(send, p9.QID{}, uint32(0), nil).Times(1)
- recv, _, _, err := f.Open(tc.mode)
+ backend.EXPECT().Open(tc.flags).Return(send, p9.QID{}, uint32(0), nil).Times(1)
+ recv, _, _, err := f.Open(tc.flags)
if err != tc.err {
- t.Fatalf("open with mode %v got %v, want %v", tc.mode, err, tc.err)
+ t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err)
}
return recv
})
// If the open was successful, attempt another one.
- if _, _, _, err := f.Open(tc.mode); err != syscall.EINVAL {
- t.Errorf("second open with mode %v got %v, want EINVAL", tc.mode, err)
+ if _, _, _, err := f.Open(tc.flags); err != syscall.EINVAL {
+ t.Errorf("second open with flags %v got %v, want EINVAL", tc.flags, err)
}
// Ensure that all illegal operations fail.
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index e717e6161..40b8fa023 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -453,7 +453,11 @@ func (cs *connState) initializeChannels() (err error) {
go func() { // S/R-SAFE: Server side.
defer cs.channelWg.Done()
if err := res.service(cs); err != nil {
- log.Warningf("p9.channel.service: %v", err)
+ // Don't log flipcall.ShutdownErrors, which we expect to be
+ // returned during server shutdown.
+ if _, ok := err.(flipcall.ShutdownError); !ok {
+ log.Warningf("p9.channel.service: %v", err)
+ }
}
}()
}
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
index f1ffdd23a..36a694c58 100644
--- a/pkg/p9/version.go
+++ b/pkg/p9/version.go
@@ -26,7 +26,7 @@ const (
//
// Clients are expected to start requesting this version number and
// to continuously decrement it until a Tversion request succeeds.
- highestSupportedVersion uint32 = 8
+ highestSupportedVersion uint32 = 9
// lowestSupportedVersion is the lowest supported version X in a
// version string of the format 9P2000.L.Google.X.
@@ -155,3 +155,9 @@ func versionSupportsTallocate(v uint32) bool {
func versionSupportsFlipcall(v uint32) bool {
return v >= 8
}
+
+// VersionSupportsOpenTruncateFlag returns true if version v supports
+// passing the OpenTruncate flag to Tlopen.
+func VersionSupportsOpenTruncateFlag(v uint32) bool {
+ return v >= 9
+}
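
Because version 9 servers are the first to understand the flag, clients must gate its use on the negotiated version and fall back to an explicit truncation otherwise. A sketch of that gating; chooseFlags is illustrative, with version assumed to come from Tversion negotiation:

    package p9util

    import "gvisor.dev/gvisor/pkg/p9"

    // chooseFlags adds OpenTruncate only when the negotiated version
    // supports it; older servers would reject or ignore unknown flag bits.
    func chooseFlags(version uint32, truncate bool) p9.OpenFlags {
            flags := p9.ReadWrite
            if truncate && p9.VersionSupportsOpenTruncateFlag(version) {
                    flags |= p9.OpenTruncate
            }
            return flags
    }
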
diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s
index 30ec8e6e2..38cea9be3 100644
--- a/pkg/procid/procid_amd64.s
+++ b/pkg/procid/procid_amd64.s
@@ -14,7 +14,7 @@
// +build amd64
// +build go1.8
-// +build !go1.14
+// +build !go1.15
#include "textflag.h"
diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s
index e340d9f98..4f4b70fef 100644
--- a/pkg/procid/procid_arm64.s
+++ b/pkg/procid/procid_arm64.s
@@ -14,7 +14,7 @@
// +build arm64
// +build go1.8
-// +build !go1.14
+// +build !go1.15
#include "textflag.h"
diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go
index c7503f2cc..fc36efa23 100644
--- a/pkg/seccomp/seccomp.go
+++ b/pkg/seccomp/seccomp.go
@@ -199,6 +199,10 @@ func ruleViolationLabel(ruleSetIdx int, sysno uintptr, idx int) string {
return fmt.Sprintf("ruleViolation_%v_%v_%v", ruleSetIdx, sysno, idx)
}
+func ruleLabel(ruleSetIdx int, sysno uintptr, idx int, name string) string {
+ return fmt.Sprintf("rule_%v_%v_%v_%v", ruleSetIdx, sysno, idx, name)
+}
+
func checkArgsLabel(sysno uintptr) string {
return fmt.Sprintf("checkArgs_%v", sysno)
}
@@ -223,6 +227,19 @@ func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, action linux.BPFAc
p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgHigh(i))
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
labelled = true
+ case GreaterThan:
+ labelGood := fmt.Sprintf("gt%v", i)
+ high, low := uint32(a>>32), uint32(a)
+ // if arg_high < high, the rule is violated
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgHigh(i))
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ // if arg_high > high, the argument matches; skip the low-word check
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ // otherwise arg_high == high, so require arg_low > low
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgLow(i))
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ labelled = true
default:
return fmt.Errorf("unknown syscall rule type: %v", reflect.TypeOf(a))
}
diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go
index 29eec8db1..84c841d7f 100644
--- a/pkg/seccomp/seccomp_rules.go
+++ b/pkg/seccomp/seccomp_rules.go
@@ -49,6 +49,9 @@ func (a AllowAny) String() (s string) {
// AllowValue specifies a value that needs to be strictly matched.
type AllowValue uintptr
+// GreaterThan specifies a value that the syscall argument must strictly exceed.
+type GreaterThan uintptr
+
func (a AllowValue) String() (s string) {
return fmt.Sprintf("%#x ", uintptr(a))
}
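
Together with the seccomp.go hunk above — which compares the high 32 bits of the argument first and falls through to the low 32 bits only on equality — this enables filters like the following sketch; the syscall choice is illustrative:

    package filtertest

    import (
            "syscall"

            "gvisor.dev/gvisor/pkg/seccomp"
    )

    // rules allows write(2) only when its third argument (count) is
    // strictly greater than 0xf; anything else falls to the default
    // action when compiled with seccomp.BuildProgram.
    var rules = seccomp.SyscallRules{
            syscall.SYS_WRITE: []seccomp.Rule{
                    {
                            seccomp.AllowAny{},
                            seccomp.AllowAny{},
                            seccomp.GreaterThan(0xf),
                    },
            },
    }
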
diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go
index 353686ed3..abbee7051 100644
--- a/pkg/seccomp/seccomp_test.go
+++ b/pkg/seccomp/seccomp_test.go
@@ -340,6 +340,54 @@ func TestBasic(t *testing.T) {
},
},
},
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ GreaterThan(0xf),
+ GreaterThan(0xabcd000d),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ specs: []spec{
+ {
+ desc: "GreaterThan: Syscall argument allowed",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xffffffff}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "GreaterThan: Syscall argument disallowed (equal)",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "Syscall argument disallowed (smaller)",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x0, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "GreaterThan2: Syscall argument allowed",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xfbcd000d}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "GreaterThan2: Syscall argument disallowed (equal)",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xabcd000d}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "GreaterThan2: Syscall argument disallowed (smaller)",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x10, 0xa000ffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
} {
instrs, err := BuildProgram(test.ruleSets, test.defaultAction)
if err != nil {
diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go
index 48413f1fb..da6b9eaaf 100644
--- a/pkg/seccomp/seccomp_test_victim.go
+++ b/pkg/seccomp/seccomp_test_victim.go
@@ -38,7 +38,7 @@ func main() {
syscall.SYS_CLONE: {},
syscall.SYS_CLOSE: {},
syscall.SYS_DUP: {},
- syscall.SYS_DUP2: {},
+ syscall.SYS_DUP3: {},
syscall.SYS_EPOLL_CREATE1: {},
syscall.SYS_EPOLL_CTL: {},
syscall.SYS_EPOLL_WAIT: {},
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 9e7db8b30..67daa6c24 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -305,7 +305,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs())
return c.Native(uintptr(usermem.ByteOrder.Uint64(buf[addr:]))), nil
}
- // TODO(b/34088053): debug registers
+ // Note: x86 debug registers are missing.
return c.Native(0), nil
}
@@ -320,6 +320,6 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error {
_, err := c.PtraceSetRegs(bytes.NewBuffer(buf))
return err
}
- // TODO(b/34088053): debug registers
+ // Note: x86 debug registers are missing.
return nil
}
diff --git a/pkg/sentry/context/context.go b/pkg/sentry/context/context.go
index dfd62cbdb..23e009ef3 100644
--- a/pkg/sentry/context/context.go
+++ b/pkg/sentry/context/context.go
@@ -12,10 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package context defines the sentry's Context type.
+// Package context defines an internal context type.
+//
+// The given Context conforms to the standard Go context, but mandates
+// additional methods that are specific to the kernel internals. Note,
+// however, that the Context described by this package carries additional
+// constraints regarding concurrent access and retention beyond the scope
+// of a call.
+//
+// See the Context type for complete details.
package context
import (
+ "context"
+ "time"
+
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/log"
)
@@ -59,6 +69,7 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) {
type Context interface {
log.Logger
amutex.Sleeper
+ context.Context
// UninterruptibleSleepStart indicates the beginning of an uninterruptible
// sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate
@@ -72,19 +83,36 @@ type Context interface {
// AddressSpace is activated. Normally activate is the same value as the
// deactivate parameter passed to UninterruptibleSleepStart.
UninterruptibleSleepFinish(activate bool)
+}
+
+// NoopSleeper is a no-op implementation of amutex.Sleeper and of the
+// Context.UninterruptibleSleep* methods, for anonymous embedding in types
+// that do not need to track sleeps.
+type NoopSleeper struct {
+ amutex.NoopSleeper
+}
+
+// UninterruptibleSleepStart does nothing.
+func (NoopSleeper) UninterruptibleSleepStart(bool) {}
+
+// UninterruptibleSleepFinish does nothing.
+func (NoopSleeper) UninterruptibleSleepFinish(bool) {}
+
+// Deadline returns zero values, meaning no deadline.
+func (NoopSleeper) Deadline() (time.Time, bool) {
+ return time.Time{}, false
+}
+
+// Done returns nil.
+func (NoopSleeper) Done() <-chan struct{} {
+ return nil
+}
- // Value returns the value associated with this Context for key, or nil if
- // no value is associated with key. Successive calls to Value with the same
- // key returns the same result.
- //
- // A key identifies a specific value in a Context. Functions that wish to
- // retrieve values from Context typically allocate a key in a global
- // variable then use that key as the argument to Context.Value. A key can
- // be any type that supports equality; packages should define keys as an
- // unexported type to avoid collisions.
- Value(key interface{}) interface{}
+// Err returns nil.
+func (NoopSleeper) Err() error {
+ return nil
}
+// logContext implements basic logging.
type logContext struct {
log.Logger
NoopSleeper
@@ -95,19 +123,6 @@ func (logContext) Value(key interface{}) interface{} {
return nil
}
-// NoopSleeper is a noop implementation of amutex.Sleeper and
-// Context.UninterruptibleSleep* methods for anonymous embedding in other types
-// that do not want to notify kernel.Task about sleeps.
-type NoopSleeper struct {
- amutex.NoopSleeper
-}
-
-// UninterruptibleSleepStart does nothing.
-func (NoopSleeper) UninterruptibleSleepStart(bool) {}
-
-// UninterruptibleSleepFinish does nothing.
-func (NoopSleeper) UninterruptibleSleepFinish(bool) {}
-
// bgContext is the context returned by context.Background.
var bgContext = &logContext{Logger: log.Log()}
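With context.Context now embedded in the interface, a lightweight Context implementation only has to supply Value and a Logger; NoopSleeper covers Deadline, Done, Err, the amutex.Sleeper methods, and the UninterruptibleSleep* pair. A minimal sketch mirroring logContext above (the mapContext name and its values field are illustrative):

    // mapContext serves Value lookups from a map and is otherwise inert.
    type mapContext struct {
        NoopSleeper // Provides Deadline, Done, Err, and the sleep methods.
        log.Logger
        values map[interface{}]interface{}
    }

    // Value implements Context.Value.
    func (c mapContext) Value(key interface{}) interface{} {
        return c.values[key]
    }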
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index 1f78d54a2..e1f2fea60 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -22,6 +22,7 @@ import (
"sync"
"gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/urpc"
)
@@ -56,6 +57,9 @@ type Profile struct {
// traceFile is the current execution trace output file.
traceFile *fd.FD
+
+ // Kernel is the kernel under profile.
+ Kernel *kernel.Kernel
}
// StartCPUProfile is an RPC stub which starts recording the CPU profile in a
@@ -147,6 +151,9 @@ func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error {
return err
}
+ // Ensure all trace contexts are registered.
+ p.Kernel.RebuildTraceContexts()
+
p.traceFile = output
return nil
}
@@ -158,9 +165,15 @@ func (p *Profile) StopTrace(_, _ *struct{}) error {
defer p.mu.Unlock()
if p.traceFile == nil {
- return errors.New("Execution tracing not start")
+ return errors.New("Execution tracing not started")
}
+ // As with StartTrace above, if tasks have not yet ended their traces,
+ // we will lose information. Rebuild the trace contexts so the trace is
+ // complete. No information is lost when multiple traces overlap.
+ p.Kernel.RebuildTraceContexts()
+
trace.Stop()
p.traceFile.Close()
p.traceFile = nil
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index c35faeb4c..ced51c66c 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -268,14 +268,17 @@ func (proc *Proc) Ps(args *PsArgs, out *string) error {
}
// Process contains information about a single process in a Sandbox.
-// TODO(b/117881927): Implement TTY field.
type Process struct {
UID auth.KUID `json:"uid"`
PID kernel.ThreadID `json:"pid"`
// Parent PID
- PPID kernel.ThreadID `json:"ppid"`
+ PPID kernel.ThreadID `json:"ppid"`
+ Threads []kernel.ThreadID `json:"threads"`
// Processor utilization
C int32 `json:"c"`
+ // TTY is the name of the process' controlling terminal: "pts/N" if it
+ // has one, or "?" if it does not.
+ TTY string `json:"tty"`
// Start time
STime string `json:"stime"`
// CPU time
@@ -285,18 +288,19 @@ type Process struct {
}
// ProcessListToTable prints a table with the following format:
-// UID PID PPID C STIME TIME CMD
-// 0 1 0 0 14:04 505262ns tail
+// UID PID PPID C TTY STIME TIME CMD
+// 0 1 0 0 pts/4 14:04 505262ns tail
func ProcessListToTable(pl []*Process) string {
var buf bytes.Buffer
tw := tabwriter.NewWriter(&buf, 10, 1, 3, ' ', 0)
- fmt.Fprint(tw, "UID\tPID\tPPID\tC\tSTIME\tTIME\tCMD")
+ fmt.Fprint(tw, "UID\tPID\tPPID\tC\tTTY\tSTIME\tTIME\tCMD")
for _, d := range pl {
- fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s",
+ fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s\t%s",
d.UID,
d.PID,
d.PPID,
d.C,
+ d.TTY,
d.STime,
d.Time,
d.Cmd)
@@ -307,7 +311,7 @@ func ProcessListToTable(pl []*Process) string {
// ProcessListToJSON will return the JSON representation of ps.
func ProcessListToJSON(pl []*Process) (string, error) {
- b, err := json.Marshal(pl)
+ b, err := json.MarshalIndent(pl, "", " ")
if err != nil {
return "", fmt.Errorf("couldn't marshal process list %v: %v", pl, err)
}
@@ -334,7 +338,9 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error {
ts := k.TaskSet()
now := k.RealtimeClock().Now()
for _, tg := range ts.Root.ThreadGroups() {
- pid := tg.PIDNamespace().IDOfThreadGroup(tg)
+ pidns := tg.PIDNamespace()
+ pid := pidns.IDOfThreadGroup(tg)
+
// If tg has already been reaped ignore it.
if pid == 0 {
continue
@@ -345,16 +351,19 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error {
ppid := kernel.ThreadID(0)
if p := tg.Leader().Parent(); p != nil {
- ppid = p.PIDNamespace().IDOfThreadGroup(p.ThreadGroup())
+ ppid = pidns.IDOfThreadGroup(p.ThreadGroup())
}
+ threads := tg.MemberIDs(pidns)
*out = append(*out, &Process{
- UID: tg.Leader().Credentials().EffectiveKUID,
- PID: pid,
- PPID: ppid,
- STime: formatStartTime(now, tg.Leader().StartTime()),
- C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now),
- Time: tg.CPUStats().SysTime.String(),
- Cmd: tg.Leader().Name(),
+ UID: tg.Leader().Credentials().EffectiveKUID,
+ PID: pid,
+ PPID: ppid,
+ Threads: threads,
+ STime: formatStartTime(now, tg.Leader().StartTime()),
+ C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now),
+ Time: tg.CPUStats().SysTime.String(),
+ Cmd: tg.Leader().Name(),
+ TTY: ttyName(tg.TTY()),
})
}
sort.Slice(*out, func(i, j int) bool { return (*out)[i].PID < (*out)[j].PID })
@@ -395,3 +404,10 @@ func percentCPU(stats usage.CPUStats, startTime, now ktime.Time) int32 {
}
return int32(percentCPU)
}
+
+func ttyName(tty *kernel.TTY) string {
+ if tty == nil {
+ return "?"
+ }
+ return fmt.Sprintf("pts/%d", tty.Index)
+}
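For reference, the two cases ttyName distinguishes:

    ttyName(nil)                   // "?"    (no controlling terminal)
    ttyName(&kernel.TTY{Index: 4}) // "pts/4"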
diff --git a/pkg/sentry/control/proc_test.go b/pkg/sentry/control/proc_test.go
index d8ada2694..0a88459b2 100644
--- a/pkg/sentry/control/proc_test.go
+++ b/pkg/sentry/control/proc_test.go
@@ -34,7 +34,7 @@ func TestProcessListTable(t *testing.T) {
}{
{
pl: []*Process{},
- expected: "UID PID PPID C STIME TIME CMD",
+ expected: "UID PID PPID C TTY STIME TIME CMD",
},
{
pl: []*Process{
@@ -43,6 +43,7 @@ func TestProcessListTable(t *testing.T) {
PID: 0,
PPID: 0,
C: 0,
+ TTY: "?",
STime: "0",
Time: "0",
Cmd: "zero",
@@ -52,14 +53,15 @@ func TestProcessListTable(t *testing.T) {
PID: 1,
PPID: 1,
C: 1,
+ TTY: "pts/4",
STime: "1",
Time: "1",
Cmd: "one",
},
},
- expected: `UID PID PPID C STIME TIME CMD
-0 0 0 0 0 0 zero
-1 1 1 1 1 1 one`,
+ expected: `UID PID PPID C TTY STIME TIME CMD
+0 0 0 0 ? 0 0 zero
+1 1 1 1 pts/4 1 1 one`,
},
}
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index 378602cc9..c035ffff7 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -68,9 +68,9 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/state",
+ "//pkg/syncutil",
"//pkg/syserror",
"//pkg/waiter",
- "//third_party/gvsync",
],
)
diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go
index 8e4d839e1..577445148 100644
--- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go
+++ b/pkg/sentry/fs/fdpipe/pipe_opener_test.go
@@ -25,6 +25,7 @@ import (
"time"
"github.com/google/uuid"
+
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
diff --git a/pkg/sentry/fs/flags.go b/pkg/sentry/fs/flags.go
index 0fab876a9..4338ae1fa 100644
--- a/pkg/sentry/fs/flags.go
+++ b/pkg/sentry/fs/flags.go
@@ -64,6 +64,10 @@ type FileFlags struct {
// NonSeekable indicates that file.offset isn't used.
NonSeekable bool
+
+ // Truncate indicates that the file should be truncated when it is
+ // opened. This only applies to regular files.
+ Truncate bool
}
// SettableFileFlags is a subset of FileFlags above that can be changed
@@ -118,6 +122,9 @@ func (f FileFlags) ToLinux() (mask uint) {
if f.LargeFile {
mask |= linux.O_LARGEFILE
}
+ if f.Truncate {
+ mask |= linux.O_TRUNC
+ }
switch {
case f.Read && f.Write:
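A quick sketch of the new mapping: a read-write open that requests truncation now round-trips to the expected Linux bits.

    flags := fs.FileFlags{Read: true, Write: true, Truncate: true}
    mask := flags.ToLinux() // mask includes linux.O_RDWR | linux.O_TRUNC.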
diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go
index c2fbb4be9..bb8312849 100644
--- a/pkg/sentry/fs/gofer/file_state.go
+++ b/pkg/sentry/fs/gofer/file_state.go
@@ -28,8 +28,14 @@ func (f *fileOperations) afterLoad() {
// Manually load the open handles.
var err error
+
+ // The file may have been opened with Truncate, but we don't
+ // want to re-open it with Truncate or we will lose data.
+ flags := f.flags
+ flags.Truncate = false
+
// TODO(b/38173783): Context is not plumbed to save/restore.
- f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), f.flags, f.inodeOperations.cachingInodeOps)
+ f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), flags, f.inodeOperations.cachingInodeOps)
if err != nil {
return fmt.Errorf("failed to re-open handle: %v", err)
}
diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go
index 39c8ec33d..b86c49b39 100644
--- a/pkg/sentry/fs/gofer/handles.go
+++ b/pkg/sentry/fs/gofer/handles.go
@@ -64,7 +64,7 @@ func (h *handles) DecRef() {
})
}
-func newHandles(ctx context.Context, file contextFile, flags fs.FileFlags) (*handles, error) {
+func newHandles(ctx context.Context, client *p9.Client, file contextFile, flags fs.FileFlags) (*handles, error) {
_, newFile, err := file.walk(ctx, nil)
if err != nil {
return nil, err
@@ -81,6 +81,9 @@ func newHandles(ctx context.Context, file contextFile, flags fs.FileFlags) (*han
default:
panic("impossible fs.FileFlags")
}
+ if flags.Truncate && p9.VersionSupportsOpenTruncateFlag(client.Version()) {
+ p9flags |= p9.OpenTruncate
+ }
hostFile, _, _, err := newFile.open(ctx, p9flags)
if err != nil {
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 99910388f..91263ebdc 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -180,7 +180,7 @@ func (i *inodeFileState) setSharedHandlesLocked(flags fs.FileFlags, h *handles)
// given flags.
func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags, cache *fsutil.CachingInodeOperations) (*handles, error) {
if !i.canShareHandles() {
- return newHandles(ctx, i.file, flags)
+ return newHandles(ctx, i.s.client, i.file, flags)
}
i.handlesMu.Lock()
@@ -201,19 +201,25 @@ func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags, cac
// whether previously open read handle was recreated. Host mappings must be
// invalidated if so.
func (i *inodeFileState) getHandlesLocked(ctx context.Context, flags fs.FileFlags) (*handles, bool, error) {
- // Do we already have usable shared handles?
- if flags.Write {
+ // Check if we are able to use cached handles.
+ if flags.Truncate && p9.VersionSupportsOpenTruncateFlag(i.s.client.Version()) {
+ // If we are truncating (and the gofer supports it), then we
+ // always need a new handle. Don't return one from the cache.
+ } else if flags.Write {
if i.writeHandles != nil && (i.writeHandlesRW || !flags.Read) {
+ // File is opened for writing, and we have cached write
+ // handles that we can use.
i.writeHandles.IncRef()
return i.writeHandles, false, nil
}
} else if i.readHandles != nil {
+ // File is opened for reading and we have cached handles.
i.readHandles.IncRef()
return i.readHandles, false, nil
}
- // No; get new handles and cache them for future sharing.
- h, err := newHandles(ctx, i.file, flags)
+ // Get new handles and cache them for future sharing.
+ h, err := newHandles(ctx, i.s.client, i.file, flags)
if err != nil {
return nil, false, err
}
@@ -239,7 +245,7 @@ func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handle
if !flags.Read {
// Writer can't be used for read, must create a new handle.
var err error
- h, err = newHandles(ctx, i.file, fs.FileFlags{Read: true})
+ h, err = newHandles(ctx, i.s.client, i.file, fs.FileFlags{Read: true})
if err != nil {
return err
}
@@ -268,7 +274,7 @@ func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handle
// operations on the old will see the new data. Then, make the new handle take
// ownership of the old FD and mark the old readHandle to not close the FD
// when done.
- if err := syscall.Dup2(h.Host.FD(), i.readHandles.Host.FD()); err != nil {
+ if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), 0); err != nil {
return err
}
diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go
index 8c17603f8..c09f3b71c 100644
--- a/pkg/sentry/fs/gofer/path.go
+++ b/pkg/sentry/fs/gofer/path.go
@@ -234,6 +234,8 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string,
if err != nil {
return nil, err
}
+ // We don't use newFile after this function returns, so close it then.
+ defer newFile.close(ctx)
// Stabilize the endpoint map while creation is in progress.
unlock := i.session().endpoints.lock()
@@ -254,7 +256,6 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string,
// Get the attributes of the file to create inode key.
qid, mask, attr, err := getattr(ctx, newFile)
if err != nil {
- newFile.close(ctx)
return nil, err
}
@@ -270,7 +271,6 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string,
// cloned and re-opened multiple times after creation.
_, unopened, err := i.fileState.file.walk(ctx, []string{name})
if err != nil {
- newFile.close(ctx)
return nil, err
}
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index 0da608548..4e358a46a 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -143,9 +143,9 @@ type session struct {
// socket files. This allows unix domain sockets to be used with paths that
// belong to a gofer.
//
- // TODO(b/77154739): there are few possible races with someone stat'ing the
- // file and another deleting it concurrently, where the file will not be
- // reported as socket file.
+ // TODO(gvisor.dev/issue/1200): there are a few possible races with
+ // someone stat'ing the file while another task deletes it concurrently,
+ // in which case the file will not be reported as a socket file.
endpoints *endpointMaps `state:"wait"`
}
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 1cbed07ae..23daeb528 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -21,6 +21,8 @@ go_library(
"socket_unsafe.go",
"tty.go",
"util.go",
+ "util_amd64_unsafe.go",
+ "util_arm64_unsafe.go",
"util_unsafe.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/fs/host",
diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go
index bad61a9a1..e37e687c6 100644
--- a/pkg/sentry/fs/host/util.go
+++ b/pkg/sentry/fs/host/util.go
@@ -155,7 +155,7 @@ func unstableAttr(mo *superOperations, s *syscall.Stat_t) fs.UnstableAttr {
AccessTime: ktime.FromUnix(s.Atim.Sec, s.Atim.Nsec),
ModificationTime: ktime.FromUnix(s.Mtim.Sec, s.Mtim.Nsec),
StatusChangeTime: ktime.FromUnix(s.Ctim.Sec, s.Ctim.Nsec),
- Links: s.Nlink,
+ Links: uint64(s.Nlink),
}
}
diff --git a/pkg/sentry/fs/host/util_amd64_unsafe.go b/pkg/sentry/fs/host/util_amd64_unsafe.go
new file mode 100644
index 000000000..66da6e9f5
--- /dev/null
+++ b/pkg/sentry/fs/host/util_amd64_unsafe.go
@@ -0,0 +1,41 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) {
+ var stat syscall.Stat_t
+ namePtr, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return stat, err
+ }
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_NEWFSTATAT,
+ uintptr(fd),
+ uintptr(unsafe.Pointer(namePtr)),
+ uintptr(unsafe.Pointer(&stat)),
+ uintptr(flags),
+ 0, 0)
+ if errno != 0 {
+ return stat, errno
+ }
+ return stat, nil
+}
diff --git a/pkg/sentry/fs/host/util_arm64_unsafe.go b/pkg/sentry/fs/host/util_arm64_unsafe.go
new file mode 100644
index 000000000..e8cb94aeb
--- /dev/null
+++ b/pkg/sentry/fs/host/util_arm64_unsafe.go
@@ -0,0 +1,41 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) {
+ var stat syscall.Stat_t
+ namePtr, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return stat, err
+ }
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_FSTATAT,
+ uintptr(fd),
+ uintptr(unsafe.Pointer(namePtr)),
+ uintptr(unsafe.Pointer(&stat)),
+ uintptr(flags),
+ 0, 0)
+ if errno != 0 {
+ return stat, errno
+ }
+ return stat, nil
+}
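Both per-arch files expose an identical fstatat signature; only the syscall number differs (SYS_NEWFSTATAT on amd64, SYS_FSTATAT on arm64), so callers stay architecture-agnostic. A usage sketch, with the wrapper name and parameters as placeholders:

    // statNoFollow stats name relative to dirFD without following a
    // trailing symlink.
    func statNoFollow(dirFD int, name string) (syscall.Stat_t, error) {
        return fstatat(dirFD, name, syscall.AT_SYMLINK_NOFOLLOW)
    }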
diff --git a/pkg/sentry/fs/host/util_unsafe.go b/pkg/sentry/fs/host/util_unsafe.go
index 2b76f1065..3ab36b088 100644
--- a/pkg/sentry/fs/host/util_unsafe.go
+++ b/pkg/sentry/fs/host/util_unsafe.go
@@ -116,22 +116,3 @@ func setTimestamps(fd int, ts fs.TimeSpec) error {
}
return nil
}
-
-func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) {
- var stat syscall.Stat_t
- namePtr, err := syscall.BytePtrFromString(name)
- if err != nil {
- return stat, err
- }
- _, _, errno := syscall.Syscall6(
- syscall.SYS_NEWFSTATAT,
- uintptr(fd),
- uintptr(unsafe.Pointer(namePtr)),
- uintptr(unsafe.Pointer(&stat)),
- uintptr(flags),
- 0, 0)
- if errno != 0 {
- return stat, errno
- }
- return stat, nil
-}
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
index f4ddfa406..91e2fde2f 100644
--- a/pkg/sentry/fs/inode.go
+++ b/pkg/sentry/fs/inode.go
@@ -270,6 +270,14 @@ func (i *Inode) Getxattr(name string) (string, error) {
return i.InodeOperations.Getxattr(i, name)
}
+// Setxattr calls i.InodeOperations.Setxattr with i as the Inode.
+func (i *Inode) Setxattr(name, value string) error {
+ if i.overlay != nil {
+ return overlaySetxattr(i.overlay, name, value)
+ }
+ return i.InodeOperations.Setxattr(i, name, value)
+}
+
// Listxattr calls i.InodeOperations.Listxattr with i as the Inode.
func (i *Inode) Listxattr() (map[string]struct{}, error) {
if i.overlay != nil {
@@ -344,6 +352,10 @@ func (i *Inode) SetTimestamps(ctx context.Context, d *Dirent, ts TimeSpec) error
// Truncate calls i.InodeOperations.Truncate with i as the Inode.
func (i *Inode) Truncate(ctx context.Context, d *Dirent, size int64) error {
+ if IsDir(i.StableAttr) {
+ return syserror.EISDIR
+ }
+
if i.overlay != nil {
return overlayTruncate(ctx, i.overlay, d, size)
}
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
index 5a388dad1..13d11e001 100644
--- a/pkg/sentry/fs/inode_overlay.go
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -436,7 +436,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena
}
func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) {
- if err := copyUp(ctx, parent); err != nil {
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
return nil, err
}
@@ -462,7 +462,9 @@ func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name stri
inode.DecRef()
return nil, err
}
- return NewDirent(ctx, newOverlayInode(ctx, entry, inode.MountSource), name), nil
+ // Use the parent's MountSource, since that corresponds to the overlay,
+ // and not the upper filesystem.
+ return NewDirent(ctx, newOverlayInode(ctx, entry, parent.Inode.MountSource), name), nil
}
func overlayBoundEndpoint(o *overlayEntry, path string) transport.BoundEndpoint {
@@ -550,6 +552,11 @@ func overlayGetxattr(o *overlayEntry, name string) (string, error) {
return s, err
}
+// TODO(b/146028302): Support setxattr for overlayfs.
+func overlaySetxattr(o *overlayEntry, name, value string) error {
+ return syserror.EOPNOTSUPP
+}
+
func overlayListxattr(o *overlayEntry) (map[string]struct{}, error) {
o.copyMu.RLock()
defer o.copyMu.RUnlock()
diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go
index 1d3ff39e0..25573e986 100644
--- a/pkg/sentry/fs/overlay.go
+++ b/pkg/sentry/fs/overlay.go
@@ -23,8 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syncutil"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/third_party/gvsync"
)
// The virtual filesystem implements an overlay configuration. For a high-level
@@ -199,7 +199,7 @@ type overlayEntry struct {
upper *Inode
// dirCacheMu protects dirCache.
- dirCacheMu gvsync.DowngradableRWMutex `state:"nosave"`
+ dirCacheMu syncutil.DowngradableRWMutex `state:"nosave"`
// dirCache is cache of DentAttrs from upper and lower Inodes.
dirCache *SortedDentryMap
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index 497475db5..75cbb0622 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -52,6 +52,7 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/syserror",
+ "//pkg/tcpip/header",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index f70239449..402919924 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -18,6 +18,7 @@ import (
"bytes"
"fmt"
"io"
+ "reflect"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -33,6 +34,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
// newNetDir creates a new proc net entry.
@@ -40,7 +43,8 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo
var contents map[string]*fs.Inode
if s := p.k.NetworkStack(); s != nil {
contents = map[string]*fs.Inode{
- "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc),
+ "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc),
+ "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc),
// The following files are simple stubs until they are
// implemented in netstack, if the file contains a
@@ -57,7 +61,7 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo
// (ClockGetres returns 1ns resolution).
"psched": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))),
"ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function")),
- "route": newStaticProcInode(ctx, msrc, []byte("Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT")),
+ "route": seqfile.NewSeqFileInode(ctx, &netRoute{s: s}, msrc),
"tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc),
"udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc),
"unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc),
@@ -195,6 +199,193 @@ func (n *netDev) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
return data, 0
}
+// netSnmp implements seqfile.SeqSource for /proc/net/snmp.
+//
+// +stateify savable
+type netSnmp struct {
+ s inet.Stack
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (n *netSnmp) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+type snmpLine struct {
+ prefix string
+ header string
+}
+
+var snmp = []snmpLine{
+ {
+ prefix: "Ip",
+ header: "Forwarding DefaultTTL InReceives InHdrErrors InAddrErrors ForwDatagrams InUnknownProtos InDiscards InDelivers OutRequests OutDiscards OutNoRoutes ReasmTimeout ReasmReqds ReasmOKs ReasmFails FragOKs FragFails FragCreates",
+ },
+ {
+ prefix: "Icmp",
+ header: "InMsgs InErrors InCsumErrors InDestUnreachs InTimeExcds InParmProbs InSrcQuenchs InRedirects InEchos InEchoReps InTimestamps InTimestampReps InAddrMasks InAddrMaskReps OutMsgs OutErrors OutDestUnreachs OutTimeExcds OutParmProbs OutSrcQuenchs OutRedirects OutEchos OutEchoReps OutTimestamps OutTimestampReps OutAddrMasks OutAddrMaskReps",
+ },
+ {
+ prefix: "IcmpMsg",
+ },
+ {
+ prefix: "Tcp",
+ header: "RtoAlgorithm RtoMin RtoMax MaxConn ActiveOpens PassiveOpens AttemptFails EstabResets CurrEstab InSegs OutSegs RetransSegs InErrs OutRsts InCsumErrors",
+ },
+ {
+ prefix: "Udp",
+ header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti",
+ },
+ {
+ prefix: "UdpLite",
+ header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti",
+ },
+}
+
+func toSlice(a interface{}) []uint64 {
+ v := reflect.Indirect(reflect.ValueOf(a))
+ return v.Slice(0, v.Len()).Interface().([]uint64)
+}
+
+func sprintSlice(s []uint64) string {
+ if len(s) == 0 {
+ return ""
+ }
+ r := fmt.Sprint(s)
+ return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice.
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. See Linux's
+// net/ipv4/proc.c:snmp_seq_show.
+func (n *netSnmp) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ contents := make([]string, 0, len(snmp)*2)
+ types := []interface{}{
+ &inet.StatSNMPIP{},
+ &inet.StatSNMPICMP{},
+ nil, // TODO(gvisor.dev/issue/628): Support IcmpMsg stats.
+ &inet.StatSNMPTCP{},
+ &inet.StatSNMPUDP{},
+ &inet.StatSNMPUDPLite{},
+ }
+ for i, stat := range types {
+ line := snmp[i]
+ if stat == nil {
+ contents = append(
+ contents,
+ fmt.Sprintf("%s:\n", line.prefix),
+ fmt.Sprintf("%s:\n", line.prefix),
+ )
+ continue
+ }
+ if err := n.s.Statistics(stat, line.prefix); err != nil {
+ if err == syserror.EOPNOTSUPP {
+ log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err)
+ } else {
+ log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err)
+ }
+ }
+ var values string
+ if line.prefix == "Tcp" {
+ tcp := stat.(*inet.StatSNMPTCP)
+ // "Tcp" needs special processing because MaxConn is signed. RFC 2012.
+ values = fmt.Sprintf("%s %d %s", sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:]))
+ } else {
+ values = sprintSlice(toSlice(stat))
+ }
+ contents = append(
+ contents,
+ fmt.Sprintf("%s: %s\n", line.prefix, line.header),
+ fmt.Sprintf("%s: %s\n", line.prefix, values),
+ )
+ }
+
+ data := make([]seqfile.SeqData, 0, len(snmp)*2)
+ for _, l := range contents {
+ data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netSnmp)(nil)})
+ }
+
+ return data, 0
+}
+
+// netRoute implements seqfile.SeqSource for /proc/net/route.
+//
+// +stateify savable
+type netRoute struct {
+ s inet.Stack
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (n *netRoute) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show.
+func (n *netRoute) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ if h != nil {
+ return nil, 0
+ }
+
+ interfaces := n.s.Interfaces()
+ contents := []string{"Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT"}
+ for _, rt := range n.s.RouteTable() {
+ // /proc/net/route only includes IPv4 routes.
+ if rt.Family != linux.AF_INET {
+ continue
+ }
+
+ // /proc/net/route does not include broadcast or multicast routes.
+ if rt.Type == linux.RTN_BROADCAST || rt.Type == linux.RTN_MULTICAST {
+ continue
+ }
+
+ iface, ok := interfaces[rt.OutputInterface]
+ if !ok || iface.Name == "lo" {
+ continue
+ }
+
+ var (
+ gw uint32
+ prefix uint32
+ flags = linux.RTF_UP
+ )
+ if len(rt.GatewayAddr) == header.IPv4AddressSize {
+ flags |= linux.RTF_GATEWAY
+ gw = usermem.ByteOrder.Uint32(rt.GatewayAddr)
+ }
+ if len(rt.DstAddr) == header.IPv4AddressSize {
+ prefix = usermem.ByteOrder.Uint32(rt.DstAddr)
+ }
+ l := fmt.Sprintf(
+ "%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d",
+ iface.Name,
+ prefix,
+ gw,
+ flags,
+ 0, // RefCnt.
+ 0, // Use.
+ 0, // Metric.
+ (uint32(1)<<rt.DstLen)-1,
+ 0, // MTU.
+ 0, // Window.
+ 0, // IRTT.
+ )
+ contents = append(contents, l)
+ }
+
+ var data []seqfile.SeqData
+ for _, l := range contents {
+ l = fmt.Sprintf("%-127s\n", l)
+ data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netRoute)(nil)})
+ }
+
+ return data, 0
+}
+
// netUnix implements seqfile.SeqSource for /proc/net/unix.
//
// +stateify savable
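toSlice relies on the inet.StatSNMP* types being fixed-size uint64 arrays: reflect.Indirect dereferences the pointer, and Slice then views the backing array as []uint64. A self-contained sketch of the trick, using a stand-in stat type:

    type statDemo [3]uint64 // Stand-in for e.g. inet.StatSNMPUDP.

    func demo() {
        s := statDemo{10, 2, 1}
        fmt.Println(sprintSlice(toSlice(&s))) // Prints "10 2 1".
    }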
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index f3b63dfc2..bd93f83fa 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -64,7 +64,7 @@ var _ fs.InodeOperations = (*tcpMemInode)(nil)
func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir tcpMemDir) *fs.Inode {
tm := &tcpMemInode{
- SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
s: s,
dir: dir,
}
@@ -77,6 +77,11 @@ func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir
return fs.NewInode(ctx, tm, msrc, sattr)
}
+// Truncate implements fs.InodeOperations.Truncate.
+func (tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
// GetFile implements fs.InodeOperations.GetFile.
func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
flags.Pread = true
@@ -168,14 +173,15 @@ func writeSize(dirType tcpMemDir, s inet.Stack, size inet.TCPBufferSize) error {
// +stateify savable
type tcpSack struct {
+ fsutil.SimpleFileInode
+
stack inet.Stack `state:"wait"`
enabled *bool
- fsutil.SimpleFileInode
}
func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
ts := &tcpSack{
- SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
stack: s,
}
sattr := fs.StableAttr{
@@ -187,6 +193,11 @@ func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *f
return fs.NewInode(ctx, ts, msrc, sattr)
}
+// Truncate implements fs.InodeOperations.Truncate.
+func (tcpSack) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
// GetFile implements fs.InodeOperations.GetFile.
func (s *tcpSack) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
flags.Pread = true
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 87184ec67..9bf4b4527 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -67,29 +67,28 @@ type taskDir struct {
var _ fs.InodeOperations = (*taskDir)(nil)
// newTaskDir creates a new proc task entry.
-func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, showSubtasks bool) *fs.Inode {
+func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode {
contents := map[string]*fs.Inode{
- "auxv": newAuxvec(t, msrc),
- "cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
- "comm": newComm(t, msrc),
- "environ": newExecArgInode(t, msrc, environExecArg),
- "exe": newExe(t, msrc),
- "fd": newFdDir(t, msrc),
- "fdinfo": newFdInfoDir(t, msrc),
- "gid_map": newGIDMap(t, msrc),
- // FIXME(b/123511468): create the correct io file for threads.
- "io": newIO(t, msrc),
+ "auxv": newAuxvec(t, msrc),
+ "cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
+ "comm": newComm(t, msrc),
+ "environ": newExecArgInode(t, msrc, environExecArg),
+ "exe": newExe(t, msrc),
+ "fd": newFdDir(t, msrc),
+ "fdinfo": newFdInfoDir(t, msrc),
+ "gid_map": newGIDMap(t, msrc),
+ "io": newIO(t, msrc, isThreadGroup),
"maps": newMaps(t, msrc),
"mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
"mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
"ns": newNamespaceDir(t, msrc),
"smaps": newSmaps(t, msrc),
- "stat": newTaskStat(t, msrc, showSubtasks, p.pidns),
+ "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns),
"statm": newStatm(t, msrc),
"status": newStatus(t, msrc, p.pidns),
"uid_map": newUIDMap(t, msrc),
}
- if showSubtasks {
+ if isThreadGroup {
contents["task"] = p.newSubtasks(t, msrc)
}
if len(p.cgroupControllers) > 0 {
@@ -605,6 +604,10 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) (
fmt.Fprintf(&buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
fmt.Fprintf(&buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
fmt.Fprintf(&buf, "Seccomp:\t%d\n", s.t.SeccompMode())
+ // We unconditionally report a single NUMA node. See
+ // pkg/sentry/syscalls/linux/sys_mempolicy.go.
+ fmt.Fprintf(&buf, "Mems_allowed:\t1\n")
+ fmt.Fprintf(&buf, "Mems_allowed_list:\t0\n")
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*statusData)(nil)}}, 0
}
@@ -619,8 +622,11 @@ type ioData struct {
ioUsage
}
-func newIO(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
- return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t)
+func newIO(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode {
+ if isThreadGroup {
+ return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t)
+ }
+ return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t}), msrc, fs.SpecialFile, t)
}
// NeedsUpdate returns whether the generation is old or not.
@@ -639,7 +645,7 @@ func (i *ioData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
io.Accumulate(i.IOUsage())
var buf bytes.Buffer
- fmt.Fprintf(&buf, "char: %d\n", io.CharsRead)
+ fmt.Fprintf(&buf, "rchar: %d\n", io.CharsRead)
fmt.Fprintf(&buf, "wchar: %d\n", io.CharsWritten)
fmt.Fprintf(&buf, "syscr: %d\n", io.ReadSyscalls)
fmt.Fprintf(&buf, "syscw: %d\n", io.WriteSyscalls)
diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go
index 19b7557d5..6b07f6bf2 100644
--- a/pkg/sentry/fs/tty/master.go
+++ b/pkg/sentry/fs/tty/master.go
@@ -76,6 +76,11 @@ func newMasterInode(ctx context.Context, d *dirInodeOperations, owner fs.FileOwn
func (mi *masterInodeOperations) Release(ctx context.Context) {
}
+// Truncate implements fs.InodeOperations.Truncate.
+func (*masterInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
// GetFile implements fs.InodeOperations.GetFile.
//
// It allocates a new terminal.
diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go
index 944c4ada1..2a51e6bab 100644
--- a/pkg/sentry/fs/tty/slave.go
+++ b/pkg/sentry/fs/tty/slave.go
@@ -72,6 +72,11 @@ func (si *slaveInodeOperations) Release(ctx context.Context) {
si.t.DecRef()
}
+// Truncate implements fs.InodeOperations.Truncate.
+func (slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
// GetFile implements fs.InodeOperations.GetFile.
//
// This may race with destruction of the terminal. If the terminal is gone, it
diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go
index ff8138820..917f90cc0 100644
--- a/pkg/sentry/fs/tty/terminal.go
+++ b/pkg/sentry/fs/tty/terminal.go
@@ -53,8 +53,8 @@ func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal
d: d,
n: n,
ld: newLineDiscipline(termios),
- masterKTTY: &kernel.TTY{},
- slaveKTTY: &kernel.TTY{},
+ masterKTTY: &kernel.TTY{Index: n},
+ slaveKTTY: &kernel.TTY{Index: n},
}
t.EnableLeakCheck("tty.Terminal")
return &t
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index 7ccff8b0d..880b7bcd3 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -38,6 +38,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/binary",
"//pkg/fd",
+ "//pkg/fspath",
"//pkg/log",
"//pkg/sentry/arch",
"//pkg/sentry/context",
diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
index 10a8083a0..177ce2cb9 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
+++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
@@ -50,7 +50,7 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys
// Create VFS.
vfsObj := vfs.New()
vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())})
if err != nil {
f.Close()
return nil, nil, nil, nil, err
@@ -81,7 +81,11 @@ func mount(b *testing.B, imagePath string, vfsfs *vfs.VirtualFilesystem, pop *vf
ctx := contexttest.Context(b)
creds := auth.CredentialsFromContext(ctx)
- if err := vfsfs.NewMount(ctx, creds, imagePath, pop, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}); err != nil {
+ if err := vfsfs.MountAt(ctx, creds, imagePath, pop, "extfs", &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: int(f.Fd()),
+ },
+ }); err != nil {
b.Fatalf("failed to mount tmpfs submount: %v", err)
}
return func() {
diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go
index cea89bcd9..a2d8c3ad6 100644
--- a/pkg/sentry/fsimpl/ext/block_map_file.go
+++ b/pkg/sentry/fsimpl/ext/block_map_file.go
@@ -154,7 +154,7 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds
toRead = len(dst)
}
- n, _ := f.regFile.inode.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff))
+ n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff))
if n < toRead {
return n, syserror.EIO
}
@@ -174,7 +174,7 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds
curChildOff := relFileOff % childCov
for i := startIdx; i < endIdx; i++ {
var childPhyBlk uint32
- err := readFromDisk(f.regFile.inode.dev, curPhyBlkOff+int64(i*4), &childPhyBlk)
+ err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk)
if err != nil {
return read, err
}
diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go
index 213aa3919..181727ef7 100644
--- a/pkg/sentry/fsimpl/ext/block_map_test.go
+++ b/pkg/sentry/fsimpl/ext/block_map_test.go
@@ -87,12 +87,14 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
mockDisk := make([]byte, mockBMDiskSize)
regFile := regularFile{
inode: inode{
+ fs: &filesystem{
+ dev: bytes.NewReader(mockDisk),
+ },
diskInode: &disklayout.InodeNew{
InodeOld: disklayout.InodeOld{
SizeLo: getMockBMFileFize(),
},
},
- dev: bytes.NewReader(mockDisk),
blkSize: uint64(mockBMBlkSize),
},
}
diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go
index 054fb42b6..a080cb189 100644
--- a/pkg/sentry/fsimpl/ext/dentry.go
+++ b/pkg/sentry/fsimpl/ext/dentry.go
@@ -41,16 +41,18 @@ func newDentry(in *inode) *dentry {
}
// IncRef implements vfs.DentryImpl.IncRef.
-func (d *dentry) IncRef(vfsfs *vfs.Filesystem) {
+func (d *dentry) IncRef() {
d.inode.incRef()
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
-func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool {
+func (d *dentry) TryIncRef() bool {
return d.inode.tryIncRef()
}
// DecRef implements vfs.DentryImpl.DecRef.
-func (d *dentry) DecRef(vfsfs *vfs.Filesystem) {
- d.inode.decRef(vfsfs.Impl().(*filesystem))
+func (d *dentry) DecRef() {
+ // FIXME(b/134676337): filesystem.mu may not be locked as required by
+ // inode.decRef().
+ d.inode.decRef()
}
diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go
index f10accafc..4b7d17dc6 100644
--- a/pkg/sentry/fsimpl/ext/ext.go
+++ b/pkg/sentry/fsimpl/ext/ext.go
@@ -40,14 +40,14 @@ var _ vfs.FilesystemType = (*FilesystemType)(nil)
// Currently there are two ways of mounting an ext(2/3/4) fs:
// 1. Specify a mount with our internal special MountType in the OCI spec.
// 2. Expose the device to the container and mount it from application layer.
-func getDeviceFd(source string, opts vfs.NewFilesystemOptions) (io.ReaderAt, error) {
+func getDeviceFd(source string, opts vfs.GetFilesystemOptions) (io.ReaderAt, error) {
if opts.InternalData == nil {
// User mount call.
// TODO(b/134676337): Open the device specified by `source` and return that.
panic("unimplemented")
}
- // NewFilesystem call originated from within the sentry.
+ // GetFilesystem call originated from within the sentry.
devFd, ok := opts.InternalData.(int)
if !ok {
return nil, errors.New("internal data for ext fs must be an int containing the file descriptor to device")
@@ -91,8 +91,8 @@ func isCompatible(sb disklayout.SuperBlock) bool {
return true
}
-// NewFilesystem implements vfs.FilesystemType.NewFilesystem.
-func (FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
// TODO(b/134676337): Ensure that the user is mounting readonly. If not,
// EACCESS should be returned according to mount(2). Filesystem independent
// flags (like readonly) are currently not available in pkg/sentry/vfs.
@@ -103,7 +103,7 @@ func (FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials
}
fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)}
- fs.vfsfs.Init(&fs)
+ fs.vfsfs.Init(vfsObj, &fs)
fs.sb, err = readSuperBlock(dev)
if err != nil {
return nil, nil, err
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 1aa2bd6a4..e9f756732 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -66,7 +66,7 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys
// Create VFS.
vfsObj := vfs.New()
vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())})
if err != nil {
f.Close()
return nil, nil, nil, nil, err
@@ -147,55 +147,54 @@ func TestSeek(t *testing.T) {
t.Fatalf("vfsfs.OpenAt failed: %v", err)
}
- if n, err := fd.Impl().Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil {
+ if n, err := fd.Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil {
t.Errorf("expected seek position 0, got %d and error %v", n, err)
}
- stat, err := fd.Impl().Stat(ctx, vfs.StatOptions{})
+ stat, err := fd.Stat(ctx, vfs.StatOptions{})
if err != nil {
t.Errorf("fd.stat failed for file %s in image %s: %v", test.path, test.image, err)
}
// We should be able to seek beyond the end of file.
size := int64(stat.Size)
- if n, err := fd.Impl().Seek(ctx, size, linux.SEEK_SET); n != size || err != nil {
+ if n, err := fd.Seek(ctx, size, linux.SEEK_SET); n != size || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", size, n, err)
}
// EINVAL should be returned if the resulting offset is negative.
- if _, err := fd.Impl().Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL {
+ if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL {
t.Errorf("expected error EINVAL but got %v", err)
}
- if n, err := fd.Impl().Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil {
+ if n, err := fd.Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", size+3, n, err)
}
// Make sure negative offsets work with SEEK_CUR.
- if n, err := fd.Impl().Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil {
+ if n, err := fd.Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err)
}
// EINVAL should be returned if the resulting offset is negative.
- if _, err := fd.Impl().Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL {
+ if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL {
t.Errorf("expected error EINVAL but got %v", err)
}
// Make sure SEEK_END works with regular files.
- switch fd.Impl().(type) {
- case *regularFileFD:
+ if _, ok := fd.Impl().(*regularFileFD); ok {
// Seek back to 0.
- if n, err := fd.Impl().Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil {
+ if n, err := fd.Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", 0, n, err)
}
// Seek forward beyond EOF.
- if n, err := fd.Impl().Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil {
+ if n, err := fd.Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err)
}
// EINVAL should be returned if the resulting offset is negative.
- if _, err := fd.Impl().Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL {
+ if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL {
t.Errorf("expected error EINVAL but got %v", err)
}
}
@@ -456,7 +455,7 @@ func TestRead(t *testing.T) {
want := make([]byte, 1)
for {
n, err := f.Read(want)
- fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{})
+ fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{})
if diff := cmp.Diff(got, want); diff != "" {
t.Errorf("file data mismatch (-want +got):\n%s", diff)
@@ -464,7 +463,7 @@ func TestRead(t *testing.T) {
// Make sure there is no more file data left after getting EOF.
if n == 0 || err == io.EOF {
- if n, _ := fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 {
+ if n, _ := fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 {
t.Errorf("extra unexpected file data in file %s in image %s", test.absPath, test.image)
}
@@ -509,27 +508,27 @@ func TestIterDirents(t *testing.T) {
}
wantDirents := []vfs.Dirent{
- vfs.Dirent{
+ {
Name: ".",
Type: linux.DT_DIR,
},
- vfs.Dirent{
+ {
Name: "..",
Type: linux.DT_DIR,
},
- vfs.Dirent{
+ {
Name: "lost+found",
Type: linux.DT_DIR,
},
- vfs.Dirent{
+ {
Name: "file.txt",
Type: linux.DT_REG,
},
- vfs.Dirent{
+ {
Name: "bigfile.txt",
Type: linux.DT_REG,
},
- vfs.Dirent{
+ {
Name: "symlink.txt",
Type: linux.DT_LNK,
},
@@ -574,7 +573,7 @@ func TestIterDirents(t *testing.T) {
}
cb := &iterDirentsCb{}
- if err = fd.Impl().IterDirents(ctx, cb); err != nil {
+ if err = fd.IterDirents(ctx, cb); err != nil {
t.Fatalf("dir fd.IterDirents() failed: %v", err)
}
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
index 38b68a2d3..3d3ebaca6 100644
--- a/pkg/sentry/fsimpl/ext/extent_file.go
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -99,7 +99,7 @@ func (f *extentFile) buildExtTree() error {
func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*disklayout.ExtentNode, error) {
var header disklayout.ExtentHeader
off := entry.PhysicalBlock() * f.regFile.inode.blkSize
- err := readFromDisk(f.regFile.inode.dev, int64(off), &header)
+ err := readFromDisk(f.regFile.inode.fs.dev, int64(off), &header)
if err != nil {
return nil, err
}
@@ -115,7 +115,7 @@ func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*diskla
curEntry = &disklayout.ExtentIdx{}
}
- err := readFromDisk(f.regFile.inode.dev, int64(off), curEntry)
+ err := readFromDisk(f.regFile.inode.fs.dev, int64(off), curEntry)
if err != nil {
return nil, err
}
@@ -229,7 +229,7 @@ func (f *extentFile) readFromExtent(ex *disklayout.Extent, off uint64, dst []byt
toRead = len(dst)
}
- n, _ := f.regFile.inode.dev.ReadAt(dst[:toRead], int64(readStart))
+ n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], int64(readStart))
if n < toRead {
return n, syserror.EIO
}
diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go
index 42d0a484b..a2382daa3 100644
--- a/pkg/sentry/fsimpl/ext/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/extent_test.go
@@ -180,13 +180,15 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, []
mockExtentFile := &extentFile{
regFile: regularFile{
inode: inode{
+ fs: &filesystem{
+ dev: bytes.NewReader(mockDisk),
+ },
diskInode: &disklayout.InodeNew{
InodeOld: disklayout.InodeOld{
SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root),
},
},
blkSize: mockExtentBlkSize,
- dev: bytes.NewReader(mockDisk),
},
},
}
diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go
index 4d18b28cb..841274daf 100644
--- a/pkg/sentry/fsimpl/ext/file_description.go
+++ b/pkg/sentry/fsimpl/ext/file_description.go
@@ -26,33 +26,14 @@ import (
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
-
- // flags is the same as vfs.OpenOptions.Flags which are passed to
- // vfs.FilesystemImpl.OpenAt.
- // TODO(b/134676337): syscalls like read(2), write(2), fchmod(2), fchown(2),
- // fgetxattr(2), ioctl(2), mmap(2) should fail with EBADF if O_PATH is set.
- // Only close(2), fstat(2), fstatfs(2) should work.
- flags uint32
}
func (fd *fileDescription) filesystem() *filesystem {
- return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem)
+ return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
}
func (fd *fileDescription) inode() *inode {
- return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode
-}
-
-// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags.
-func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) {
- return fd.flags, nil
-}
-
-// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags.
-func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error {
- // None of the flags settable by fcntl(F_SETFL) are supported, so this is a
- // no-op.
- return nil
+ return fd.vfsfd.Dentry().Impl().(*dentry).inode
}
// Stat implements vfs.FileDescriptionImpl.Stat.
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index 2d15e8aaf..d7e87979a 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -20,6 +20,7 @@ import (
"sync"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -441,3 +442,46 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return syserror.EROFS
}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return nil, err
+ }
+ return nil, syserror.ENOTSUP
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return "", err
+ }
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ _, _, err := fs.walk(rp, false)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ return vfs.GenericPrependPath(vfsroot, vd, b)
+}
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index e6c847a71..b2cc826c7 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -16,7 +16,6 @@ package ext
import (
"fmt"
- "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -42,13 +41,13 @@ type inode struct {
// refs is a reference count. refs is accessed using atomic memory operations.
refs int64
+ // fs is the containing filesystem.
+ fs *filesystem
+
// inodeNum is the inode number of this inode on disk. This is used to
// identify inodes within the ext filesystem.
inodeNum uint32
- // dev represents the underlying device. Same as filesystem.dev.
- dev io.ReaderAt
-
// blkSize is the fs data block size. Same as filesystem.sb.BlockSize().
blkSize uint64
@@ -81,10 +80,10 @@ func (in *inode) tryIncRef() bool {
// decRef decrements the inode ref count and releases the inode resources if
// the ref count hits 0.
//
-// Precondition: Must have locked fs.mu.
-func (in *inode) decRef(fs *filesystem) {
+// Precondition: Must have locked filesystem.mu.
+func (in *inode) decRef() {
if refs := atomic.AddInt64(&in.refs, -1); refs == 0 {
- delete(fs.inodeCache, in.inodeNum)
+ delete(in.fs.inodeCache, in.inodeNum)
} else if refs < 0 {
panic("ext.inode.decRef() called without holding a reference")
}
@@ -117,8 +116,8 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
// Build the inode based on its type.
inode := inode{
+ fs: fs,
inodeNum: inodeNum,
- dev: fs.dev,
blkSize: blkSize,
diskInode: diskInode,
}
@@ -154,11 +153,13 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
if err := in.checkPermissions(rp.Credentials(), ats); err != nil {
return nil, err
}
+ mnt := rp.Mount()
switch in.impl.(type) {
case *regularFile:
var fd regularFileFD
- fd.flags = flags
- fd.vfsfd.Init(&fd, rp.Mount(), vfsd)
+ mnt.IncRef()
+ vfsd.IncRef()
+ fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
case *directory:
// Can't open directories writably. This check is not necessary for a read
@@ -167,8 +168,9 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
return nil, syserror.EISDIR
}
var fd directoryFD
- fd.vfsfd.Init(&fd, rp.Mount(), vfsd)
- fd.flags = flags
+ mnt.IncRef()
+ vfsd.IncRef()
+ fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
case *symlink:
if flags&linux.O_PATH == 0 {
@@ -176,8 +178,9 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
return nil, syserror.ELOOP
}
var fd symlinkFD
- fd.flags = flags
- fd.vfsfd.Init(&fd, rp.Mount(), vfsd)
+ mnt.IncRef()
+ vfsd.IncRef()
+ fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
default:
panic(fmt.Sprintf("unknown inode type: %T", in.impl))
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
new file mode 100644
index 000000000..52596c090
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -0,0 +1,60 @@
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "slot_list",
+ out = "slot_list.go",
+ package = "kernfs",
+ prefix = "slot",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*slot",
+ "Linker": "*slot",
+ },
+)
+
+go_library(
+ name = "kernfs",
+ srcs = [
+ "dynamic_bytes_file.go",
+ "fd_impl_util.go",
+ "filesystem.go",
+ "inode_impl_util.go",
+ "kernfs.go",
+ "slot_list.go",
+ ],
+ importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/sentry/context",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ ],
+)
+
+go_test(
+ name = "kernfs_test",
+ size = "small",
+ srcs = ["kernfs_test.go"],
+ deps = [
+ ":kernfs",
+ "//pkg/abi/linux",
+ "//pkg/sentry/context",
+ "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
new file mode 100644
index 000000000..51102ce48
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -0,0 +1,117 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// DynamicBytesFile implements kernfs.Inode and represents a read-only
+// file whose contents are backed by a vfs.DynamicBytesSource.
+//
+// Must be initialized with Init before first use.
+type DynamicBytesFile struct {
+ InodeAttrs
+ InodeNoopRefCount
+ InodeNotDirectory
+ InodeNotSymlink
+
+ data vfs.DynamicBytesSource
+}
+
+// Init initializes a dynamic bytes file.
+func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.DynamicBytesSource) {
+ f.InodeAttrs.Init(creds, ino, linux.ModeRegular|0444)
+ f.data = data
+}
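+
+// For illustration only, a minimal sketch of client usage (versionData is a
+// hypothetical vfs.DynamicBytesSource; creds and ino are assumed to come from
+// the surrounding filesystem):
+//
+//   type versionData struct{}
+//
+//   func (versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+//       _, err := buf.WriteString("4.4.0\n")
+//       return err
+//   }
+//
+//   f := &DynamicBytesFile{}
+//   f.Init(creds, ino, versionData{})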
+
+// Open implements Inode.Open.
+func (f *DynamicBytesFile) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ fd := &DynamicBytesFD{}
+ fd.Init(rp.Mount(), vfsd, f.data, flags)
+ return &fd.vfsfd, nil
+}
+
+// SetStat implements Inode.SetStat.
+func (f *DynamicBytesFile) SetStat(*vfs.Filesystem, vfs.SetStatOptions) error {
+ // DynamicBytesFiles are immutable.
+ return syserror.EPERM
+}
+
+// DynamicBytesFD implements vfs.FileDescriptionImpl for an FD backed by a
+// DynamicBytesFile.
+//
+// Must be initialized with Init before first use.
+type DynamicBytesFD struct {
+ vfs.FileDescriptionDefaultImpl
+ vfs.DynamicBytesFileDescriptionImpl
+
+ vfsfd vfs.FileDescription
+ inode Inode
+}
+
+// Init initializes a DynamicBytesFD.
+func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, flags uint32) {
+ m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref.
+ d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref.
+ fd.inode = d.Impl().(*Dentry).inode
+ fd.SetDataSource(data)
+ fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{})
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *DynamicBytesFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return fd.DynamicBytesFileDescriptionImpl.Seek(ctx, offset, whence)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *DynamicBytesFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return fd.DynamicBytesFileDescriptionImpl.Read(ctx, dst, opts)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *DynamicBytesFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return fd.DynamicBytesFileDescriptionImpl.PRead(ctx, dst, offset, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *DynamicBytesFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return fd.FileDescriptionDefaultImpl.Write(ctx, src, opts)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *DynamicBytesFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return fd.FileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts)
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *DynamicBytesFD) Release() {}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return fd.inode.Stat(fs), nil
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error {
+ // DynamicBytesFiles are immutable.
+ return syserror.EPERM
+}
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
new file mode 100644
index 000000000..bd402330f
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -0,0 +1,193 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// GenericDirectoryFD implements vfs.FileDescriptionImpl for a generic directory
+// inode that uses OrderedChildren to track child nodes. GenericDirectoryFD is
+// not compatible with dynamic directories.
+//
+// Note that GenericDirectoryFD holds a lock over OrderedChildren while calling
+// the IterDirents callback. The IterDirents callback therefore cannot hash or
+// unhash children, or recursively call IterDirents on the same underlying
+// inode.
+//
+// Must be initialized with Init before first use.
+type GenericDirectoryFD struct {
+ vfs.FileDescriptionDefaultImpl
+ vfs.DirectoryFileDescriptionDefaultImpl
+
+ vfsfd vfs.FileDescription
+ children *OrderedChildren
+ off int64
+}
+
+// Init initializes a GenericDirectoryFD.
+func (fd *GenericDirectoryFD) Init(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, flags uint32) {
+ m.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref.
+ d.IncRef() // DecRef in vfs.FileDescription.vd.DecRef on final ref.
+ fd.children = children
+ fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{})
+}
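+
+// For illustration only, a sketch of how a directory inode embedding
+// OrderedChildren might implement Inode.Open using GenericDirectoryFD
+// (staticDir is hypothetical):
+//
+//   func (dir *staticDir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+//       fd := &GenericDirectoryFD{}
+//       fd.Init(rp.Mount(), vfsd, &dir.OrderedChildren, flags)
+//       return fd.VFSFileDescription(), nil
+//   }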
+
+// VFSFileDescription returns a pointer to the vfs.FileDescription representing
+// this object.
+func (fd *GenericDirectoryFD) VFSFileDescription() *vfs.FileDescription {
+ return &fd.vfsfd
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *GenericDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return fd.FileDescriptionDefaultImpl.ConfigureMMap(ctx, opts)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *GenericDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return fd.DirectoryFileDescriptionDefaultImpl.Read(ctx, dst, opts)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *GenericDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return fd.DirectoryFileDescriptionDefaultImpl.PRead(ctx, dst, offset, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *GenericDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return fd.DirectoryFileDescriptionDefaultImpl.Write(ctx, src, opts)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts)
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *GenericDirectoryFD) Release() {}
+
+func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem {
+ return fd.vfsfd.VirtualDentry().Mount().Filesystem()
+}
+
+func (fd *GenericDirectoryFD) inode() Inode {
+ return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds
+// o.mu when calling cb.
+func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ vfsFS := fd.filesystem()
+ fs := vfsFS.Impl().(*Filesystem)
+ vfsd := fd.vfsfd.VirtualDentry().Dentry()
+
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+
+ // Handle ".".
+ if fd.off == 0 {
+ stat := fd.inode().Stat(vfsFS)
+ dirent := vfs.Dirent{
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: stat.Ino,
+ NextOff: 1,
+ }
+ if !cb.Handle(dirent) {
+ return nil
+ }
+ fd.off++
+ }
+
+ // Handle "..".
+ if fd.off == 1 {
+ parentInode := vfsd.ParentOrSelf().Impl().(*Dentry).inode
+ stat := parentInode.Stat(vfsFS)
+ dirent := vfs.Dirent{
+ Name: "..",
+ Type: linux.FileMode(stat.Mode).DirentType(),
+ Ino: stat.Ino,
+ NextOff: 2,
+ }
+ if !cb.Handle(dirent) {
+ return nil
+ }
+ fd.off++
+ }
+
+ // Handle static children.
+ fd.children.mu.RLock()
+ defer fd.children.mu.RUnlock()
+ // fd.off accounts for "." and "..", but fd.children does not track
+ // these.
+ childIdx := fd.off - 2
+ for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() {
+ inode := it.Dentry.Impl().(*Dentry).inode
+ stat := inode.Stat(vfsFS)
+ dirent := vfs.Dirent{
+ Name: it.Name,
+ Type: linux.FileMode(stat.Mode).DirentType(),
+ Ino: stat.Ino,
+ NextOff: fd.off + 1,
+ }
+ if !cb.Handle(dirent) {
+ return nil
+ }
+ fd.off++
+ }
+
+ return nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fs := fd.filesystem().Impl().(*Filesystem)
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ // Use offset as given.
+ case linux.SEEK_CUR:
+ offset += fd.off
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := fd.filesystem()
+ inode := fd.inode()
+ return inode.Stat(fs), nil
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ fs := fd.filesystem()
+ inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
+ return inode.SetStat(fs, opts)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
new file mode 100644
index 000000000..3cbbe4b20
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -0,0 +1,743 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file implements vfs.FilesystemImpl for kernfs.
+
+package kernfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// stepExistingLocked resolves rp.Component() in parent directory vfsd.
+//
+// stepExistingLocked is loosely analogous to fs/namei.c:walk_component().
+//
+// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done().
+//
+// Postcondition: Caller must call fs.processDeferredDecRefs*.
+func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) (*vfs.Dentry, error) {
+ d := vfsd.Impl().(*Dentry)
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ // Directory searchable?
+ if err := d.inode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+afterSymlink:
+ d.dirMu.Lock()
+ nextVFSD, err := rp.ResolveComponent(vfsd)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ if nextVFSD != nil {
+ // Cached dentry exists, revalidate.
+ next := nextVFSD.Impl().(*Dentry)
+ if !next.inode.Valid(ctx) {
+ d.dirMu.Lock()
+ rp.VirtualFilesystem().ForceDeleteDentry(nextVFSD)
+ d.dirMu.Unlock()
+ fs.deferDecRef(nextVFSD) // Reference from Lookup.
+ nextVFSD = nil
+ }
+ }
+ if nextVFSD == nil {
+ // Dentry isn't cached; it either doesn't exist or failed
+ // revalidation. Attempt to resolve it via Lookup.
+ name := rp.Component()
+ var err error
+ nextVFSD, err = d.inode.Lookup(ctx, name)
+ // The reference on nextVFSD from Lookup is dropped when it later fails a Valid check.
+ if err != nil {
+ return nil, err
+ }
+ d.InsertChild(name, nextVFSD)
+ }
+ next := nextVFSD.Impl().(*Dentry)
+
+ // Resolve any symlink at the current path component. Note that the check
+ // must be on next, the dentry just resolved, not on the parent d.
+ if rp.ShouldFollowSymlink() && next.isSymlink() {
+ // TODO: VFS2 needs something extra for /proc/[pid]/fd/ "magic symlinks".
+ target, err := next.inode.Readlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ goto afterSymlink
+ }
+ rp.Advance()
+ return nextVFSD, nil
+}
+
+// walkExistingLocked resolves rp to an existing file.
+//
+// walkExistingLocked is loosely analogous to Linux's
+// fs/namei.c:path_lookupat().
+//
+// Preconditions: Filesystem.mu must be locked for at least reading.
+//
+// Postconditions: Caller must call fs.processDeferredDecRefs*.
+func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) {
+ vfsd := rp.Start()
+ for !rp.Done() {
+ var err error
+ vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ d := vfsd.Impl().(*Dentry)
+ if rp.MustBeDir() && !d.isDir() {
+ return nil, nil, syserror.ENOTDIR
+ }
+ return vfsd, d.inode, nil
+}
+
+// walkParentDirLocked resolves all but the last path component of rp to an
+// existing directory. It does not check that the returned directory is
+// searchable by the provider of rp.
+//
+// walkParentDirLocked is loosely analogous to Linux's
+// fs/namei.c:path_parentat().
+//
+// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done().
+//
+// Postconditions: Caller must call fs.processDeferredDecRefs*.
+func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) {
+ vfsd := rp.Start()
+ for !rp.Final() {
+ var err error
+ vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ d := vfsd.Impl().(*Dentry)
+ if !d.isDir() {
+ return nil, nil, syserror.ENOTDIR
+ }
+ return vfsd, d.inode, nil
+}
+
+// checkCreateLocked checks that a file named rp.Component() may be created in
+// directory parentVFSD, then returns rp.Component().
+//
+// Preconditions: Filesystem.mu must be locked for at least reading. parentInode
+// == parentVFSD.Impl().(*Dentry).Inode. isDir(parentInode) == true.
+func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) {
+ if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return "", err
+ }
+ pc := rp.Component()
+ if pc == "." || pc == ".." {
+ return "", syserror.EEXIST
+ }
+ childVFSD, err := rp.ResolveChild(parentVFSD, pc)
+ if err != nil {
+ return "", err
+ }
+ if childVFSD != nil {
+ return "", syserror.EEXIST
+ }
+ if parentVFSD.IsDisowned() {
+ return "", syserror.ENOENT
+ }
+ return pc, nil
+}
+
+// checkDeleteLocked checks that the file represented by vfsd may be deleted.
+//
+// Preconditions: Filesystem.mu must be locked for at least reading.
+func checkDeleteLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error {
+ parentVFSD := vfsd.Parent()
+ if parentVFSD == nil {
+ return syserror.EBUSY
+ }
+ if parentVFSD.IsDisowned() {
+ return syserror.ENOENT
+ }
+ if err := parentVFSD.Impl().(*Dentry).inode.CheckPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ return nil
+}
+
+// checkRenameLocked checks that a rename operation may be performed on the
+// target dentry across the given set of parent directories. The target dentry
+// may be nil.
+//
+// Precondition: isDir(dstInode) == true.
+func checkRenameLocked(creds *auth.Credentials, src, dstDir *vfs.Dentry, dstInode Inode) error {
+ srcDir := src.Parent()
+ if srcDir == nil {
+ return syserror.EBUSY
+ }
+ if srcDir.IsDisowned() {
+ return syserror.ENOENT
+ }
+ if dstDir.IsDisowned() {
+ return syserror.ENOENT
+ }
+ // Check for creation permissions on dst dir.
+ if err := dstInode.CheckPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *Filesystem) Release() {
+}
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *Filesystem) Sync(ctx context.Context) error {
+ // All filesystem state is in-memory.
+ return nil
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ fs.mu.RLock()
+ defer fs.processDeferredDecRefs()
+ defer fs.mu.RUnlock()
+ vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
+ if err != nil {
+ return nil, err
+ }
+
+ if opts.CheckSearchable {
+ d := vfsd.Impl().(*Dentry)
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := inode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ }
+ vfsd.IncRef() // Ownership transferred to caller.
+ return vfsd, nil
+}
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
+ pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
+ if err != nil {
+ return err
+ }
+ if rp.Mount() != vd.Mount() {
+ return syserror.EXDEV
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+
+ d := vd.Dentry().Impl().(*Dentry)
+ if d.isDir() {
+ return syserror.EPERM
+ }
+
+ child, err := parentInode.NewLink(ctx, pc, d.inode)
+ if err != nil {
+ return err
+ }
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, child)
+ return nil
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
+ pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ child, err := parentInode.NewDir(ctx, pc, opts)
+ if err != nil {
+ return err
+ }
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, child)
+ return nil
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
+ pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ new, err := parentInode.NewNode(ctx, pc, opts)
+ if err != nil {
+ return err
+ }
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, new)
+ return nil
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // Filter out flags that are not supported by kernfs. O_DIRECTORY and
+ // O_NOFOLLOW have no effect here (they're handled by VFS by setting
+ // appropriate bits in rp), but are returned by
+ // FileDescriptionImpl.StatusFlags().
+ opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW
+ ats := vfs.AccessTypesForOpenFlags(opts.Flags)
+
+ // Do not create new file.
+ if opts.Flags&linux.O_CREAT == 0 {
+ fs.mu.RLock()
+ defer fs.processDeferredDecRefs()
+ defer fs.mu.RUnlock()
+ vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
+ if err != nil {
+ return nil, err
+ }
+ if err := inode.CheckPermissions(rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ return inode.Open(rp, vfsd, opts.Flags)
+ }
+
+ // May create new file.
+ mustCreate := opts.Flags&linux.O_EXCL != 0
+ vfsd := rp.Start()
+ inode := vfsd.Impl().(*Dentry).inode
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ if rp.Done() {
+ if rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ if err := inode.CheckPermissions(rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ return inode.Open(rp, vfsd, opts.Flags)
+ }
+afterTrailingSymlink:
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return nil, err
+ }
+ // Check for search permission in the parent directory.
+ if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ // Reject attempts to open directories with O_CREAT.
+ if rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ pc := rp.Component()
+ if pc == "." || pc == ".." {
+ return nil, syserror.EISDIR
+ }
+ // Determine whether or not we need to create a file.
+ childVFSD, err := rp.ResolveChild(parentVFSD, pc)
+ if err != nil {
+ return nil, err
+ }
+ if childVFSD == nil {
+ // Already checked for searchability above; now check for writability.
+ if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ defer rp.Mount().EndWrite()
+ // Create and open the child.
+ child, err := parentInode.NewFile(ctx, pc, opts)
+ if err != nil {
+ return nil, err
+ }
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, child)
+ return child.Impl().(*Dentry).inode.Open(rp, child, opts.Flags)
+ }
+ // Open existing file or follow symlink.
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ childDentry := childVFSD.Impl().(*Dentry)
+ childInode := childDentry.inode
+ if rp.ShouldFollowSymlink() {
+ if childDentry.isSymlink() {
+ target, err := childInode.Readlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ // rp.Final() may no longer be true since we now need to resolve the
+ // symlink target.
+ goto afterTrailingSymlink
+ }
+ }
+ if err := childInode.CheckPermissions(rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ return childInode.Open(rp, childVFSD, opts.Flags)
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ fs.mu.RLock()
+ d, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return "", err
+ }
+ if !d.Impl().(*Dentry).isSymlink() {
+ return "", syserror.EINVAL
+ }
+ return inode.Readlink(ctx)
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error {
+ noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0
+ exchange := opts.Flags&linux.RENAME_EXCHANGE != 0
+ whiteout := opts.Flags&linux.RENAME_WHITEOUT != 0
+ if exchange && (noReplace || whiteout) {
+ // Can't specify RENAME_NOREPLACE or RENAME_WHITEOUT with RENAME_EXCHANGE.
+ return syserror.EINVAL
+ }
+ if exchange || whiteout {
+ // Exchange and Whiteout flags are not supported on kernfs.
+ return syserror.EINVAL
+ }
+
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+
+ mnt := rp.Mount()
+ if mnt != vd.Mount() {
+ return syserror.EXDEV
+ }
+
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+
+ dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
+
+ srcVFSD := vd.Dentry()
+ srcDirVFSD := srcVFSD.Parent()
+
+ // Can we remove the src dentry?
+ if err := checkDeleteLocked(rp, srcVFSD); err != nil {
+ return err
+ }
+
+ // Can we create the dst dentry?
+ var dstVFSD *vfs.Dentry
+ pc, err := checkCreateLocked(rp, dstDirVFSD, dstDirInode)
+ switch err {
+ case nil:
+ // Ok, continue with rename as replacement.
+ case syserror.EEXIST:
+ if noReplace {
+ // Won't overwrite existing node since RENAME_NOREPLACE was requested.
+ return syserror.EEXIST
+ }
+ dstVFSD, err = rp.ResolveChild(dstDirVFSD, pc)
+ if err != nil {
+ panic(fmt.Sprintf("Child %q for parent Dentry %+v disappeared inside atomic section?", pc, dstDirVFSD))
+ }
+ default:
+ return err
+ }
+
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ virtfs := rp.VirtualFilesystem()
+
+ srcDirDentry := srcDirVFSD.Impl().(*Dentry)
+ dstDirDentry := dstDirVFSD.Impl().(*Dentry)
+
+ // We can't deadlock here due to lock ordering because we're protected from
+ // concurrent renames by fs.mu held for writing.
+ srcDirDentry.dirMu.Lock()
+ defer srcDirDentry.dirMu.Unlock()
+ dstDirDentry.dirMu.Lock()
+ defer dstDirDentry.dirMu.Unlock()
+
+ if err := virtfs.PrepareRenameDentry(mntns, srcVFSD, dstVFSD); err != nil {
+ return err
+ }
+ srcDirInode := srcDirDentry.inode
+ replaced, err := srcDirInode.Rename(ctx, srcVFSD.Name(), pc, srcVFSD, dstDirVFSD)
+ if err != nil {
+ virtfs.AbortRenameDentry(srcVFSD, dstVFSD)
+ return err
+ }
+ virtfs.CommitRenameReplaceDentry(srcVFSD, dstDirVFSD, pc, replaced)
+ return nil
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ if err := checkDeleteLocked(rp, vfsd); err != nil {
+ return err
+ }
+ if !vfsd.Impl().(*Dentry).isDir() {
+ return syserror.ENOTDIR
+ }
+ if inode.HasChildren() {
+ return syserror.ENOTEMPTY
+ }
+ virtfs := rp.VirtualFilesystem()
+ parentDentry := vfsd.Parent().Impl().(*Dentry)
+ parentDentry.dirMu.Lock()
+ defer parentDentry.dirMu.Unlock()
+ if err := virtfs.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil {
+ return err
+ }
+ if err := parentDentry.inode.RmDir(ctx, rp.Component(), vfsd); err != nil {
+ virtfs.AbortDeleteDentry(vfsd)
+ return err
+ }
+ virtfs.CommitDeleteDentry(vfsd)
+ return nil
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ fs.mu.RLock()
+ _, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return err
+ }
+ if opts.Stat.Mask == 0 {
+ return nil
+ }
+ return inode.SetStat(fs.VFSFilesystem(), opts)
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ fs.mu.RLock()
+ _, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ return inode.Stat(fs.VFSFilesystem()), nil
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ // TODO: actually implement statfs
+ return linux.Statfs{}, syserror.ENOSYS
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
+ pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ child, err := parentInode.NewSymlink(ctx, pc, target)
+ if err != nil {
+ return err
+ }
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, child)
+ return nil
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ vfsd, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked()
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ if err := checkDeleteLocked(rp, vfsd); err != nil {
+ return err
+ }
+ if vfsd.Impl().(*Dentry).isDir() {
+ return syserror.EISDIR
+ }
+ virtfs := rp.VirtualFilesystem()
+ parentDentry := vfsd.Parent().Impl().(*Dentry)
+ parentDentry.dirMu.Lock()
+ defer parentDentry.dirMu.Unlock()
+ if err := virtfs.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil {
+ return err
+ }
+ if err := parentDentry.inode.Unlink(ctx, rp.Component(), vfsd); err != nil {
+ virtfs.AbortDeleteDentry(vfsd)
+ return err
+ }
+ virtfs.CommitDeleteDentry(vfsd)
+ return nil
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return nil, err
+ }
+ // kernfs currently does not support extended attributes.
+ return nil, syserror.ENOTSUP
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return "", err
+ }
+ // kernfs currently does not support extended attributes.
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return err
+ }
+ // kernfs currently does not support extended attributes.
+ return syserror.ENOTSUP
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs()
+ if err != nil {
+ return err
+ }
+ // kernfs currently does not support extended attributes.
+ return syserror.ENOTSUP
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *Filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ return vfs.GenericPrependPath(vfsroot, vd, b)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
new file mode 100644
index 000000000..7b45b702a
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -0,0 +1,492 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// InodeNoopRefCount partially implements the Inode interface, specifically the
+// inodeRefs sub interface. InodeNoopRefCount provides no-op reference
+// counting: obtaining and releasing references does nothing, and TryIncRef
+// always succeeds. This is suitable for simple file inodes that don't
+// reference any resources.
+type InodeNoopRefCount struct {
+}
+
+// IncRef implements Inode.IncRef.
+func (n *InodeNoopRefCount) IncRef() {
+}
+
+// DecRef implements Inode.DecRef.
+func (n *InodeNoopRefCount) DecRef() {
+}
+
+// TryIncRef implements Inode.TryIncRef.
+func (n *InodeNoopRefCount) TryIncRef() bool {
+ return true
+}
+
+// Destroy implements Inode.Destroy.
+func (n *InodeNoopRefCount) Destroy() {
+}
+
+// InodeDirectoryNoNewChildren partially implements the Inode interface.
+// InodeDirectoryNoNewChildren represents a directory inode which does not
+// support creation of new children.
+type InodeDirectoryNoNewChildren struct{}
+
+// NewFile implements Inode.NewFile.
+func (*InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewDir implements Inode.NewDir.
+func (*InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewLink implements Inode.NewLink.
+func (*InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewSymlink implements Inode.NewSymlink.
+func (*InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewNode implements Inode.NewNode.
+func (*InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// InodeNotDirectory partially implements the Inode interface, specifically the
+// inodeDirectory and inodeDynamicDirectory sub interfaces. Inodes that do not
+// represent directories can embed this to provide no-op implementations for
+// directory-related functions.
+type InodeNotDirectory struct {
+}
+
+// HasChildren implements Inode.HasChildren.
+func (*InodeNotDirectory) HasChildren() bool {
+ return false
+}
+
+// NewFile implements Inode.NewFile.
+func (*InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+ panic("NewFile called on non-directory inode")
+}
+
+// NewDir implements Inode.NewDir.
+func (*InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+ panic("NewDir called on non-directory inode")
+}
+
+// NewLink implements Inode.NewLink.
+func (*InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+ panic("NewLink called on non-directory inode")
+}
+
+// NewSymlink implements Inode.NewSymlink.
+func (*InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+ panic("NewSymlink called on non-directory inode")
+}
+
+// NewNode implements Inode.NewNode.
+func (*InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+ panic("NewNode called on non-directory inode")
+}
+
+// Unlink implements Inode.Unlink.
+func (*InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error {
+ panic("Unlink called on non-directory inode")
+}
+
+// RmDir implements Inode.RmDir.
+func (*InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error {
+ panic("RmDir called on non-directory inode")
+}
+
+// Rename implements Inode.Rename.
+func (*InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) {
+ panic("Rename called on non-directory inode")
+}
+
+// Lookup implements Inode.Lookup.
+func (*InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ panic("Lookup called on non-directory inode")
+}
+
+// Valid implements Inode.Valid.
+func (*InodeNotDirectory) Valid(context.Context) bool {
+ return true
+}
+
+// InodeNoDynamicLookup partially implements the Inode interface, specifically
+// the inodeDynamicLookup sub interface. Directory inodes that do not support
+// dynamic entries (i.e. entries that are not "hashed" into the
+// vfs.Dentry.children) can embed this to provide no-op implementations for
+// functions related to dynamic entries.
+type InodeNoDynamicLookup struct{}
+
+// Lookup implements Inode.Lookup.
+func (*InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ return nil, syserror.ENOENT
+}
+
+// Valid implements Inode.Valid.
+func (*InodeNoDynamicLookup) Valid(ctx context.Context) bool {
+ return true
+}
+
+// InodeNotSymlink partially implements the Inode interface, specifically the
+// inodeSymlink sub interface. All inodes that are not symlinks may embed this
+// to return the appropriate errors from symlink-related functions.
+type InodeNotSymlink struct{}
+
+// Readlink implements Inode.Readlink.
+func (*InodeNotSymlink) Readlink(context.Context) (string, error) {
+ return "", syserror.EINVAL
+}
+
+// InodeAttrs partially implements the Inode interface, specifically the
+// inodeMetadata sub interface. InodeAttrs provides functionality related to
+// inode attributes.
+//
+// Must be initialized by Init prior to first use.
+type InodeAttrs struct {
+ ino uint64
+ mode uint32
+ uid uint32
+ gid uint32
+ nlink uint32
+}
+
+// Init initializes this InodeAttrs.
+func (a *InodeAttrs) Init(creds *auth.Credentials, ino uint64, mode linux.FileMode) {
+ if mode.FileType() == 0 {
+ panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode))
+ }
+
+ nlink := uint32(1)
+ if mode.FileType() == linux.ModeDirectory {
+ nlink = 2
+ }
+ atomic.StoreUint64(&a.ino, ino)
+ atomic.StoreUint32(&a.mode, uint32(mode))
+ atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID))
+ atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID))
+ atomic.StoreUint32(&a.nlink, nlink)
+}
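+
+// For illustration only, a sketch of a simple read-only file inode built from
+// these partial implementations (readOnlyFile is hypothetical):
+//
+//   type readOnlyFile struct {
+//       InodeAttrs
+//       InodeNoopRefCount
+//       InodeNotDirectory
+//       InodeNotSymlink
+//   }
+//
+//   func newReadOnlyFile(creds *auth.Credentials, ino uint64) *readOnlyFile {
+//       f := &readOnlyFile{}
+//       f.InodeAttrs.Init(creds, ino, linux.ModeRegular|0444)
+//       return f
+//   }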
+
+// Mode implements Inode.Mode.
+func (a *InodeAttrs) Mode() linux.FileMode {
+ return linux.FileMode(atomic.LoadUint32(&a.mode))
+}
+
+// Stat partially implements Inode.Stat. Note that this function doesn't provide
+// all the stat fields, and the embedder should consider extending the result
+// with filesystem-specific fields.
+func (a *InodeAttrs) Stat(*vfs.Filesystem) linux.Statx {
+ var stat linux.Statx
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK
+ stat.Ino = atomic.LoadUint64(&a.ino)
+ stat.Mode = uint16(a.Mode())
+ stat.UID = atomic.LoadUint32(&a.uid)
+ stat.GID = atomic.LoadUint32(&a.gid)
+ stat.Nlink = atomic.LoadUint32(&a.nlink)
+
+ // TODO: Implement other stat fields like timestamps.
+
+ return stat
+}
+
+// SetStat implements Inode.SetStat.
+func (a *InodeAttrs) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error {
+ stat := opts.Stat
+ if stat.Mask&linux.STATX_MODE != 0 {
+ for {
+ old := atomic.LoadUint32(&a.mode)
+ // Preserve the file type; replace the permission bits. Merely
+ // OR-ing in the new mode could never clear a permission bit.
+ new := (old & uint32(linux.S_IFMT)) | uint32(stat.Mode&^uint16(linux.S_IFMT))
+ if swapped := atomic.CompareAndSwapUint32(&a.mode, old, new); swapped {
+ break
+ }
+ }
+ }
+
+ if stat.Mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&a.uid, stat.UID)
+ }
+ if stat.Mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&a.gid, stat.GID)
+ }
+
+ // Note that not all fields are modifiable. For example, the file type and
+ // inode numbers are immutable after node creation.
+
+ // TODO: Implement other stat fields like timestamps.
+
+ return nil
+}
+
+// CheckPermissions implements Inode.CheckPermissions.
+func (a *InodeAttrs) CheckPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ mode := a.Mode()
+ return vfs.GenericCheckPermissions(
+ creds,
+ ats,
+ mode.FileType() == linux.ModeDirectory,
+ uint16(mode),
+ auth.KUID(atomic.LoadUint32(&a.uid)),
+ auth.KGID(atomic.LoadUint32(&a.gid)),
+ )
+}
+
+// IncLinks implements Inode.IncLinks.
+func (a *InodeAttrs) IncLinks(n uint32) {
+ if atomic.AddUint32(&a.nlink, n) <= n {
+ panic("InodeLink.IncLinks called with no existing links")
+ }
+}
+
+// DecLinks implements Inode.DecLinks.
+func (a *InodeAttrs) DecLinks() {
+ if nlink := atomic.AddUint32(&a.nlink, ^uint32(0)); nlink == ^uint32(0) {
+ // Negative overflow
+ panic("Inode.DecLinks called at 0 links")
+ }
+}
+
+type slot struct {
+ Name string
+ Dentry *vfs.Dentry
+ slotEntry
+}
+
+// OrderedChildrenOptions contains initialization options for OrderedChildren.
+type OrderedChildrenOptions struct {
+ // Writable indicates whether vfs.FilesystemImpl methods implemented by
+ // OrderedChildren may modify the tracked children. This applies to
+ // operations related to rename, unlink and rmdir. If an OrderedChildren is
+ // not writable, these operations all fail with EPERM.
+ Writable bool
+}
+
+// OrderedChildren partially implements the Inode interface. OrderedChildren can
+// be embedded in directory inodes to keep track of the children in the
+// directory, and can then be used to implement a generic directory FD -- see
+// GenericDirectoryFD. OrderedChildren is not compatible with dynamic
+// directories.
+//
+// Must be initialized with Init before first use.
+type OrderedChildren struct {
+ refs.AtomicRefCount
+
+ // Can children be modified by user syscalls? If set to false, interface
+ // methods that would modify the children return EPERM. Immutable.
+ writable bool
+
+ mu sync.RWMutex
+ order slotList
+ set map[string]*slot
+}
+
+// Init initializes an OrderedChildren.
+func (o *OrderedChildren) Init(opts OrderedChildrenOptions) {
+ o.writable = opts.Writable
+ o.set = make(map[string]*slot)
+}
+
+// DecRef implements Inode.DecRef.
+func (o *OrderedChildren) DecRef() {
+ o.AtomicRefCount.DecRefWithDestructor(o.Destroy)
+}
+
+// Destroy cleans up resources referenced by this OrderedChildren.
+func (o *OrderedChildren) Destroy() {
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ o.order.Reset()
+ o.set = nil
+}
+
+// Populate inserts children into this OrderedChildren and into d's dentry
+// cache. Populate returns the number of directories inserted, which the caller
+// may use to update the link count for the parent directory.
+//
+// Precondition: d.Impl() must be a kernfs Dentry. d must represent a directory
+// inode. children must not contain any conflicting entries already in o.
+func (o *OrderedChildren) Populate(d *Dentry, children map[string]*Dentry) uint32 {
+ var links uint32
+ for name, child := range children {
+ if child.isDir() {
+ links++
+ }
+ if err := o.Insert(name, child.VFSDentry()); err != nil {
+ panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v) into %+v", name, child, d))
+ }
+ d.InsertChild(name, child.VFSDentry())
+ }
+ return links
+}
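+
+// For illustration only, a sketch of directory construction using
+// OrderedChildren and Populate (staticDir and children are hypothetical):
+//
+//   dir := &staticDir{}
+//   dir.InodeAttrs.Init(creds, ino, linux.ModeDirectory|0755)
+//   dir.OrderedChildren.Init(OrderedChildrenOptions{})
+//   d := &Dentry{}
+//   d.Init(dir)
+//   dir.IncLinks(dir.Populate(d, children)) // One extra link per child directory.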
+
+// HasChildren implements Inode.HasChildren.
+func (o *OrderedChildren) HasChildren() bool {
+ o.mu.RLock()
+ defer o.mu.RUnlock()
+ return len(o.set) > 0
+}
+
+// Insert inserts child into o. This ignores the writability of o, as this is
+// not part of the vfs.FilesystemImpl interface, and is a lower-level operation.
+func (o *OrderedChildren) Insert(name string, child *vfs.Dentry) error {
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ if _, ok := o.set[name]; ok {
+ return syserror.EEXIST
+ }
+ s := &slot{
+ Name: name,
+ Dentry: child,
+ }
+ o.order.PushBack(s)
+ o.set[name] = s
+ return nil
+}
+
+// Precondition: caller must hold o.mu for writing.
+func (o *OrderedChildren) removeLocked(name string) {
+ if s, ok := o.set[name]; ok {
+ delete(o.set, name)
+ o.order.Remove(s)
+ }
+}
+
+// Precondition: caller must hold o.mu for writing.
+func (o *OrderedChildren) replaceChildLocked(name string, new *vfs.Dentry) *vfs.Dentry {
+ if s, ok := o.set[name]; ok {
+ // Existing slot with given name, simply replace the dentry.
+ var old *vfs.Dentry
+ old, s.Dentry = s.Dentry, new
+ return old
+ }
+
+ // No existing slot with given name, create and hash new slot.
+ s := &slot{
+ Name: name,
+ Dentry: new,
+ }
+ o.order.PushBack(s)
+ o.set[name] = s
+ return nil
+}
+
+// Precondition: caller must hold o.mu for reading or writing.
+func (o *OrderedChildren) checkExistingLocked(name string, child *vfs.Dentry) error {
+ s, ok := o.set[name]
+ if !ok {
+ return syserror.ENOENT
+ }
+ if s.Dentry != child {
+ panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! OrderedChild: %+v, vfs: %+v", s.Dentry, child))
+ }
+ return nil
+}
+
+// Unlink implements Inode.Unlink.
+func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.Dentry) error {
+ if !o.writable {
+ return syserror.EPERM
+ }
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ if err := o.checkExistingLocked(name, child); err != nil {
+ return err
+ }
+ o.removeLocked(name)
+ return nil
+}
+
+// RmDir implements Inode.RmDir.
+func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *vfs.Dentry) error {
+ // We're not responsible for checking that child is a directory, that it's
+ // empty, or updating any link counts, so this is the same as unlink.
+ return o.Unlink(ctx, name, child)
+}
+
+type renameAcrossDifferentImplementationsError struct{}
+
+func (renameAcrossDifferentImplementationsError) Error() string {
+ return "rename across inodes with different implementations"
+}
+
+// Rename implements Inode.Rename.
+//
+// Precondition: Rename may only be called across two directory inodes with
+// identical implementations of Rename. Practically, this means filesystems that
+// implement Rename by embedding OrderedChildren for any directory
+// implementation must use OrderedChildren for all directory implementations
+// that will support Rename.
+//
+// Postcondition: reference on any replaced dentry transferred to caller.
+func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (*vfs.Dentry, error) {
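+ // The conversion through interface{} is needed because *OrderedChildren
+ // does not implement all of Inode, so a direct type assertion on an Inode
+ // value would not compile.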
+ dst, ok := dstDir.Impl().(*Dentry).inode.(interface{}).(*OrderedChildren)
+ if !ok {
+ return nil, renameAcrossDifferentImplementationsError{}
+ }
+ if !o.writable || !dst.writable {
+ return nil, syserror.EPERM
+ }
+ // Note: There's a potential deadlock below if concurrent calls to Rename
+ // refer to the same src and dst directories in reverse. We avoid any
+ // ordering issues because the caller is required to serialize concurrent
+ // calls to Rename in accordance with the interface declaration.
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ if dst != o {
+ dst.mu.Lock()
+ defer dst.mu.Unlock()
+ }
+ if err := o.checkExistingLocked(oldname, child); err != nil {
+ return nil, err
+ }
+ replaced := dst.replaceChildLocked(newname, child)
+ return replaced, nil
+}
+
+// nthLocked returns an iterator to the nth child tracked by this object. The
+// iterator is valid until the caller releases o.mu. Returns nil if the
+// requested index falls out of bounds.
+//
+// Precondition: Caller must hold o.mu for reading.
+func (o *OrderedChildren) nthLocked(i int64) *slot {
+ for it := o.order.Front(); it != nil && i >= 0; it = it.Next() {
+ if i == 0 {
+ return it
+ }
+ i--
+ }
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
new file mode 100644
index 000000000..bb01c3d01
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -0,0 +1,405 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package kernfs provides the tools to implement inode-based filesystems.
+// Kernfs has two main features:
+//
+// 1. The Inode interface, which maps VFS2's path-based filesystem operations to
+// specific filesystem nodes. Kernfs uses the Inode interface to provide a
+// blanket implementation for the vfs.FilesystemImpl. Kernfs also serves as
+// the synchronization mechanism for all filesystem operations by holding a
+// filesystem-wide lock across all operations.
+//
+// 2. Various utility types which provide generic implementations for various
+// parts of the Inode and vfs.FileDescription interfaces. Client filesystems
+// based on kernfs can embed the appropriate set of these to avoid having to
+// reimplement common filesystem operations. See inode_impl_util.go and
+// fd_impl_util.go.
+//
+// Reference Model:
+//
+// Kernfs dentries represent named pointers to inodes. Dentries and inodes have
+// independent lifetimes and reference counts. A child dentry unconditionally
+// holds a reference on its parent directory's dentry. A dentry also holds a
+// reference on the inode it points to. Multiple dentries can point to the same
+// inode (for example, in the case of hardlinks). File descriptors hold a
+// reference to the dentry they're opened on.
+//
+// Dentries are guaranteed to exist while holding Filesystem.mu for
+// reading. Dropping dentries requires holding Filesystem.mu for writing. To
+// queue dentries for destruction from a read critical section, see
+// Filesystem.deferDecRef.
+//
+// Lock ordering:
+//
+// kernfs.Filesystem.mu
+// kernfs.Dentry.dirMu
+// vfs.VirtualFilesystem.mountMu
+// vfs.Dentry.mu
+// kernfs.Filesystem.droppedDentriesMu
+// (inode implementation locks, if any)
+package kernfs
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory
+// filesystem. Concrete implementations are expected to embed this in their own
+// Filesystem type.
+type Filesystem struct {
+ vfsfs vfs.Filesystem
+
+ droppedDentriesMu sync.Mutex
+
+ // droppedDentries is a list of dentries waiting to be DecRef()ed. This is
+ // used to defer dentry destruction until mu can be acquired for
+ // writing. Protected by droppedDentriesMu.
+ droppedDentries []*vfs.Dentry
+
+ // mu synchronizes the lifetime of Dentries on this filesystem. Holding it
+ // for reading guarantees continued existence of any resolved dentries, but
+ // the dentry tree may be modified.
+ //
+ // Kernfs dentries can only be DecRef()ed while holding mu for writing. For
+ // example:
+ //
+ // fs.mu.Lock()
+ // defer fs.mu.Unlock()
+ // ...
+ // dentry1.DecRef()
+ // defer dentry2.DecRef() // Ok, will run before Unlock.
+ //
+ // If discarding dentries in a read context, use Filesystem.deferDecRef. For
+ // example:
+ //
+ // fs.mu.RLock()
+ // defer fs.processDeferredDecRefs() // Runs after RUnlock below.
+ // defer fs.mu.RUnlock()
+ // ...
+ // fs.deferDecRef(dentry)
+ mu sync.RWMutex
+
+ // nextInoMinusOne is used to allocate inode numbers on this
+ // filesystem. Must be accessed by atomic operations.
+ nextInoMinusOne uint64
+}
+
+// deferDecRef defers dropping a dentry ref until the next call to
+// processDeferredDecRefs{,Locked}. See comment on Filesystem.mu.
+//
+// Precondition: d must not already be pending destruction.
+func (fs *Filesystem) deferDecRef(d *vfs.Dentry) {
+ fs.droppedDentriesMu.Lock()
+ fs.droppedDentries = append(fs.droppedDentries, d)
+ fs.droppedDentriesMu.Unlock()
+}
+
+// processDeferredDecRefs calls vfs.Dentry.DecRef on all dentries in the
+// droppedDentries list. See comment on Filesystem.mu.
+func (fs *Filesystem) processDeferredDecRefs() {
+ fs.mu.Lock()
+ fs.processDeferredDecRefsLocked()
+ fs.mu.Unlock()
+}
+
+// Precondition: fs.mu must be held for writing.
+func (fs *Filesystem) processDeferredDecRefsLocked() {
+ fs.droppedDentriesMu.Lock()
+ for _, d := range fs.droppedDentries {
+ d.DecRef()
+ }
+ fs.droppedDentries = fs.droppedDentries[:0] // Keep slice memory for reuse.
+ fs.droppedDentriesMu.Unlock()
+}
+
+// Init initializes a kernfs filesystem. This should be called during
+// vfs.FilesystemType.NewFilesystem for the concrete filesystem embedding
+// kernfs.
+func (fs *Filesystem) Init(vfsObj *vfs.VirtualFilesystem) {
+ fs.vfsfs.Init(vfsObj, fs)
+}
+
+// VFSFilesystem returns the generic vfs filesystem object.
+func (fs *Filesystem) VFSFilesystem() *vfs.Filesystem {
+ return &fs.vfsfs
+}
+
+// NextIno allocates a new inode number on this filesystem.
+func (fs *Filesystem) NextIno() uint64 {
+ return atomic.AddUint64(&fs.nextInoMinusOne, 1)
+}
+
+// These consts are used in the Dentry.flags field.
+const (
+ // Dentry points to a directory inode.
+ dflagsIsDir = 1 << iota
+
+ // Dentry points to a symlink inode.
+ dflagsIsSymlink
+)
+
+// Dentry implements vfs.DentryImpl.
+//
+// A kernfs dentry is similar to a dentry in a traditional filesystem: it's a
+// named reference to an inode. A dentry generally lives as long as it's part of
+// a mounted filesystem tree. Kernfs doesn't cache dentries once all references
+// to them are removed. Dentries hold a single reference to the inode they point
+// to, and child dentries hold a reference on their parent.
+//
+// Must be initialized by Init prior to first use.
+type Dentry struct {
+ refs.AtomicRefCount
+
+ vfsd vfs.Dentry
+ inode Inode
+
+ // flags caches useful information about the dentry from the inode. See the
+ // dflags* consts above. Must be accessed by atomic ops.
+ flags uint32
+
+ // dirMu protects vfsd.children for directory dentries.
+ dirMu sync.Mutex
+}
+
+// Init initializes this dentry.
+//
+// Precondition: Caller must hold a reference on inode.
+//
+// Postcondition: Caller's reference on inode is transferred to the dentry.
+func (d *Dentry) Init(inode Inode) {
+ d.vfsd.Init(d)
+ d.inode = inode
+ ftype := inode.Mode().FileType()
+ if ftype == linux.ModeDirectory {
+ d.flags |= dflagsIsDir
+ }
+ if ftype == linux.ModeSymlink {
+ d.flags |= dflagsIsSymlink
+ }
+}
+
+// VFSDentry returns the generic vfs dentry for this kernfs dentry.
+func (d *Dentry) VFSDentry() *vfs.Dentry {
+ return &d.vfsd
+}
+
+// isDir checks whether the dentry points to a directory inode.
+func (d *Dentry) isDir() bool {
+ return atomic.LoadUint32(&d.flags)&dflagsIsDir != 0
+}
+
+// isSymlink checks whether the dentry points to a symlink inode.
+func (d *Dentry) isSymlink() bool {
+ return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *Dentry) DecRef() {
+ d.AtomicRefCount.DecRefWithDestructor(d.destroy)
+}
+
+// Precondition: Dentry must be removed from VFS' dentry cache.
+func (d *Dentry) destroy() {
+ d.inode.DecRef() // IncRef from Init.
+ d.inode = nil
+ if parent := d.vfsd.Parent(); parent != nil {
+ parent.DecRef() // IncRef from Dentry.InsertChild.
+ }
+}
+
+// InsertChild inserts child into the vfs dentry cache with the given name under
+// this dentry. This does not update the directory inode, so calling this on
+// its own isn't sufficient to insert a child into a directory; the caller is
+// also responsible for updating the directory inode, including any link
+// counts.
+//
+// Precondition: d must represent a directory inode.
+func (d *Dentry) InsertChild(name string, child *vfs.Dentry) {
+ if !d.isDir() {
+ panic(fmt.Sprintf("InsertChild called on non-directory Dentry: %+v.", d))
+ }
+ vfsDentry := d.VFSDentry()
+ vfsDentry.IncRef() // DecRef in child's Dentry.destroy.
+ d.dirMu.Lock()
+ vfsDentry.InsertChild(child, name)
+ d.dirMu.Unlock()
+}
+
+// The Inode interface maps filesystem-level operations that operate on paths to
+// equivalent operations on specific filesystem nodes.
+//
+// The interface methods are grouped into logical categories as sub-interfaces
+// below. Generally, an implementation for each sub-interface can be provided
+// by embedding an appropriate type from inode_impl_util.go. The sub-interfaces
+// are purely organizational. Methods declared directly in the main interface
+// have no generic implementations, and should be explicitly provided by the
+// client filesystem.
+//
+// Generally, implementations are not responsible for tasks that are common to
+// all filesystems. These include:
+//
+// - Checking that dentries passed to methods are of the appropriate file type.
+// - Checking permissions.
+// - Updating link and reference counts.
+//
+// Specific responsibilities of implementations are documented below.
+type Inode interface {
+ // Methods related to reference counting. A generic implementation is
+ // provided by InodeNoopRefCount. These methods are generally called by the
+ // equivalent Dentry methods.
+ inodeRefs
+
+ // Methods related to node metadata. A generic implementation is provided by
+ // InodeAttrs.
+ inodeMetadata
+
+ // Method for inodes that represent symlinks. InodeNotSymlink provides a
+ // blanket implementation for all non-symlink inodes.
+ inodeSymlink
+
+ // Methods for inodes that represent directories. InodeNotDirectory provides
+ // a blanket implementation for all non-directory inodes.
+ inodeDirectory
+
+ // Methods for inodes that represent dynamic directories and their
+ // children. InodeNoDynamicLookup provides a blanket implementation for all
+ // non-dynamic-directory inodes.
+ inodeDynamicLookup
+
+ // Open creates a file description for the filesystem object represented by
+ // this inode. The returned file description should hold a reference on the
+ // inode for its lifetime.
+ //
+ // Precondition: !rp.Done(). vfsd.Impl() must be a kernfs Dentry.
+ Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error)
+}
+
+type inodeRefs interface {
+ IncRef()
+ DecRef()
+ TryIncRef() bool
+ // Destroy is called when the inode reaches zero references. Destroy
+ // releases all resources (references) on objects referenced by the inode,
+ // including any child dentries.
+ Destroy()
+}
+
+type inodeMetadata interface {
+ // CheckPermissions checks that creds may access this inode for the
+ // requested access type, per the rules of
+ // fs/namei.c:generic_permission().
+ CheckPermissions(creds *auth.Credentials, atx vfs.AccessTypes) error
+
+ // Mode returns the (struct stat)::st_mode value for this inode. This is
+ // separated from Stat for performance.
+ Mode() linux.FileMode
+
+ // Stat returns the metadata for this inode. This corresponds to
+ // vfs.FilesystemImpl.StatAt.
+ Stat(fs *vfs.Filesystem) linux.Statx
+
+ // SetStat updates the metadata for this inode. This corresponds to
+ // vfs.FilesystemImpl.SetStatAt.
+ SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error
+}
+
+// Precondition: All methods in this interface may only be called on directory
+// inodes.
+type inodeDirectory interface {
+ // The New{File,Dir,Node,Symlink} methods below should return a new inode
+ // hashed into this inode.
+ //
+ // These inode constructors are inode-level operations rather than
+ // filesystem-level operations to allow client filesystems to mix different
+ // implementations based on the new node's location in the
+ // filesystem.
+
+ // HasChildren returns true if the directory inode has any children.
+ HasChildren() bool
+
+ // NewFile creates a new regular file inode.
+ NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error)
+
+ // NewDir creates a new directory inode.
+ NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error)
+
+ // NewLink creates a new hardlink to a specified inode in this
+ // directory. Implementations should create a new kernfs Dentry pointing to
+ // target, and update target's link count.
+ NewLink(ctx context.Context, name string, target Inode) (*vfs.Dentry, error)
+
+ // NewSymlink creates a new symbolic link inode.
+ NewSymlink(ctx context.Context, name, target string) (*vfs.Dentry, error)
+
+ // NewNode creates a new filesystem node for a mknod syscall.
+ NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*vfs.Dentry, error)
+
+ // Unlink removes a child dentry from this directory inode.
+ Unlink(ctx context.Context, name string, child *vfs.Dentry) error
+
+ // RmDir removes an empty child directory from this directory
+ // inode. Implementations must update the parent directory's link count,
+ // if required. Implementations are not responsible for checking that child
+ // is a directory, or for checking that the directory is empty.
+ RmDir(ctx context.Context, name string, child *vfs.Dentry) error
+
+ // Rename is called on the source directory containing an inode being
+ // renamed. child should point to the resolved child in the source
+ // directory. If Rename replaces a dentry in the destination directory, it
+ // should return the replaced dentry; otherwise it should return nil.
+ //
+ // Precondition: Caller must serialize concurrent calls to Rename.
+ Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (replaced *vfs.Dentry, err error)
+}
+
+type inodeDynamicLookup interface {
+ // Lookup should return an appropriate dentry if name should resolve to a
+ // child of this dynamic directory inode. This gives the directory an
+ // opportunity on every lookup to resolve additional entries that aren't
+ // hashed into the directory. This is only called when the inode is a
+ // directory. If the inode is not a directory, or if the directory only
+ // contains a static set of children, the implementer can unconditionally
+ // return an appropriate error (ENOTDIR and ENOENT respectively).
+ //
+ // The child returned by Lookup will be hashed into the VFS dentry tree. Its
+ // lifetime can be controlled by the filesystem implementation with an
+ // appropriate implementation of Valid.
+ //
+ // Lookup returns the child with an extra reference and the caller owns this
+ // reference.
+ Lookup(ctx context.Context, name string) (*vfs.Dentry, error)
+
+ // Valid should return true if this inode is still valid, and false if it
+ // needs to be resolved again by a call to Lookup.
+ Valid(ctx context.Context) bool
+}
+
+type inodeSymlink interface {
+ // Readlink resolves the target of a symbolic link. If an inode is not a
+ // symlink, the implementation should return EINVAL.
+ Readlink(ctx context.Context) (string, error)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
new file mode 100644
index 000000000..f78bb7b04
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -0,0 +1,423 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs_test
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "runtime"
+ "sync"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const defaultMode linux.FileMode = 01777
+const staticFileContent = "This is sample content for a static test file."
+
+// RootDentryFn is a generator function for creating the root dentry of a test
+// filesystem. See newTestSystem.
+type RootDentryFn func(*auth.Credentials, *filesystem) *kernfs.Dentry
+
+// TestSystem represents the context for a single test.
+type TestSystem struct {
+ t *testing.T
+ ctx context.Context
+ creds *auth.Credentials
+ vfs *vfs.VirtualFilesystem
+ mns *vfs.MountNamespace
+ root vfs.VirtualDentry
+}
+
+// newTestSystem sets up a minimal environment for running a test, including an
+// instance of a test filesystem. Tests can control the contents of the
+// filesystem by providing an appropriate rootFn, which should return a
+// pre-populated root dentry.
+func newTestSystem(t *testing.T, rootFn RootDentryFn) *TestSystem {
+ ctx := contexttest.Context(t)
+ creds := auth.CredentialsFromContext(ctx)
+ v := vfs.New()
+ v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn})
+ mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("Failed to create testfs root mount: %v", err)
+ }
+
+ s := &TestSystem{
+ t: t,
+ ctx: ctx,
+ creds: creds,
+ vfs: v,
+ mns: mns,
+ root: mns.Root(),
+ }
+ runtime.SetFinalizer(s, func(s *TestSystem) { s.root.DecRef() })
+ return s
+}
+
+// PathOpAtRoot constructs a vfs.PathOperation for a path from the
+// root of the test filesystem.
+//
+// Precondition: path should be a relative path.
+func (s *TestSystem) PathOpAtRoot(path string) vfs.PathOperation {
+ return vfs.PathOperation{
+ Root: s.root,
+ Start: s.root,
+ Pathname: path,
+ }
+}
+
+// GetDentryOrDie attempts to resolve a dentry referred to by the
+// provided path operation. If unsuccessful, the test fails.
+func (s *TestSystem) GetDentryOrDie(pop vfs.PathOperation) vfs.VirtualDentry {
+ vd, err := s.vfs.GetDentryAt(s.ctx, s.creds, &pop, &vfs.GetDentryOptions{})
+ if err != nil {
+ s.t.Fatalf("GetDentryAt(pop:%+v) failed: %v", pop, err)
+ }
+ return vd
+}
+
+func (s *TestSystem) ReadToEnd(fd *vfs.FileDescription) (string, error) {
+ buf := make([]byte, usermem.PageSize)
+ bufIOSeq := usermem.BytesIOSequence(buf)
+ opts := vfs.ReadOptions{}
+
+ var content bytes.Buffer
+ for {
+ n, err := fd.Impl().Read(s.ctx, bufIOSeq, opts)
+ if n == 0 || err != nil {
+ if err == io.EOF {
+ err = nil
+ }
+ return content.String(), err
+ }
+ content.Write(buf[:n])
+ }
+}
+
+type fsType struct {
+ rootFn RootDentryFn
+}
+
+type filesystem struct {
+ kernfs.Filesystem
+}
+
+type file struct {
+ kernfs.DynamicBytesFile
+ content string
+}
+
+func (fs *filesystem) newFile(creds *auth.Credentials, content string) *kernfs.Dentry {
+ f := &file{}
+ f.content = content
+ f.DynamicBytesFile.Init(creds, fs.NextIno(), f)
+
+ d := &kernfs.Dentry{}
+ d.Init(f)
+ return d
+}
+
+func (f *file) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%s", f.content)
+ return nil
+}
+
+type attrs struct {
+ kernfs.InodeAttrs
+}
+
+func (a *attrs) SetStat(fs *vfs.Filesystem, opt vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+type readonlyDir struct {
+ attrs
+ kernfs.InodeNotSymlink
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeDirectoryNoNewChildren
+
+ kernfs.OrderedChildren
+ dentry kernfs.Dentry
+}
+
+func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+ dir := &readonlyDir{}
+ dir.attrs.Init(creds, fs.NextIno(), linux.ModeDirectory|mode)
+ dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ dir.dentry.Init(dir)
+
+ dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents))
+
+ return &dir.dentry
+}
+
+func (d *readonlyDir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ fd := &kernfs.GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags)
+ return fd.VFSFileDescription(), nil
+}
+
+type dir struct {
+ attrs
+ kernfs.InodeNotSymlink
+ kernfs.InodeNoDynamicLookup
+
+ fs *filesystem
+ dentry kernfs.Dentry
+ kernfs.OrderedChildren
+}
+
+func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+ dir := &dir{}
+ dir.fs = fs
+ dir.attrs.Init(creds, fs.NextIno(), linux.ModeDirectory|mode)
+ dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true})
+ dir.dentry.Init(dir)
+
+ dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents))
+
+ return &dir.dentry
+}
+
+func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ fd := &kernfs.GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags)
+ return fd.VFSFileDescription(), nil
+}
+
+func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) {
+ creds := auth.CredentialsFromContext(ctx)
+ dir := d.fs.newDir(creds, opts.Mode, nil)
+ dirVFSD := dir.VFSDentry()
+ if err := d.OrderedChildren.Insert(name, dirVFSD); err != nil {
+ dir.DecRef()
+ return nil, err
+ }
+ d.IncLinks(1)
+ return dirVFSD, nil
+}
+
+func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) {
+ creds := auth.CredentialsFromContext(ctx)
+ f := d.fs.newFile(creds, "")
+ fVFSD := f.VFSDentry()
+ if err := d.OrderedChildren.Insert(name, fVFSD); err != nil {
+ f.DecRef()
+ return nil, err
+ }
+ return fVFSD, nil
+}
+
+func (*dir) NewLink(context.Context, string, kernfs.Inode) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+func (*dir) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+func (fst *fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ fs := &filesystem{}
+ fs.Init(vfsObj)
+ root := fst.rootFn(creds, fs)
+ return fs.VFSFilesystem(), root.VFSDentry(), nil
+}
+
+// -------------------- Remainder of the file contains test cases --------------------
+
+func TestBasic(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ "file1": fs.newFile(creds, staticFileContent),
+ })
+ })
+ sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef()
+}
+
+func TestMkdirGetDentry(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ "dir1": fs.newDir(creds, 0755, nil),
+ })
+ })
+
+ pop := sys.PathOpAtRoot("dir1/a new directory")
+ if err := sys.vfs.MkdirAt(sys.ctx, sys.creds, &pop, &vfs.MkdirOptions{Mode: 0755}); err != nil {
+ t.Fatalf("MkdirAt for PathOperation %+v failed: %v", pop, err)
+ }
+ sys.GetDentryOrDie(pop).DecRef()
+}
+
+func TestReadStaticFile(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ "file1": fs.newFile(creds, staticFileContent),
+ })
+ })
+
+ pop := sys.PathOpAtRoot("file1")
+ fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{})
+ if err != nil {
+ sys.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
+ }
+ defer fd.DecRef()
+
+ content, err := sys.ReadToEnd(fd)
+ if err != nil {
+ sys.t.Fatalf("Read failed: %v", err)
+ }
+ if diff := cmp.Diff(staticFileContent, content); diff != "" {
+ sys.t.Fatalf("Read returned unexpected data:\n--- want\n+++ got\n%v", diff)
+ }
+}
+
+func TestCreateNewFileInStaticDir(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ "dir1": fs.newDir(creds, 0755, nil),
+ })
+ })
+
+ pop := sys.PathOpAtRoot("dir1/newfile")
+ opts := &vfs.OpenOptions{Flags: linux.O_CREAT | linux.O_EXCL, Mode: defaultMode}
+ fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, opts)
+ if err != nil {
+ sys.t.Fatalf("OpenAt(pop:%+v, opts:%+v) failed: %v", pop, opts, err)
+ }
+
+ // Close the file. The file should persist.
+ fd.DecRef()
+
+ fd, err = sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{})
+ if err != nil {
+ sys.t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err)
+ }
+ fd.DecRef()
+}
+
+// direntCollector provides an implementation for vfs.IterDirentsCallback for
+// testing. It collects all dirents emitted through the callback while a
+// directory FD is iterated to the end.
+type direntCollector struct {
+ mu sync.Mutex
+ dirents map[string]vfs.Dirent
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle.
+func (d *direntCollector) Handle(dirent vfs.Dirent) bool {
+ d.mu.Lock()
+ if d.dirents == nil {
+ d.dirents = make(map[string]vfs.Dirent)
+ }
+ d.dirents[dirent.Name] = dirent
+ d.mu.Unlock()
+ return true
+}
+
+// count returns the number of dirents currently in the collector.
+func (d *direntCollector) count() int {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return len(d.dirents)
+}
+
+// contains checks whether the collector has a dirent with the given name and
+// type.
+func (d *direntCollector) contains(name string, typ uint8) error {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ dirent, ok := d.dirents[name]
+ if !ok {
+ return fmt.Errorf("No dirent named %q found", name)
+ }
+ if dirent.Type != typ {
+ return fmt.Errorf("Dirent named %q found, but was expecting type %d, got: %+v", name, typ, dirent)
+ }
+ return nil
+}
+
+func TestDirFDReadWrite(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, nil)
+ })
+
+ pop := sys.PathOpAtRoot("/")
+ fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{})
+ if err != nil {
+ sys.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
+ }
+ defer fd.DecRef()
+
+ // Read/Write should fail for directory FDs.
+ if _, err := fd.Read(sys.ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); err != syserror.EISDIR {
+ sys.t.Fatalf("Read for directory FD failed with unexpected error: %v", err)
+ }
+ if _, err := fd.Write(sys.ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); err != syserror.EISDIR {
+ sys.t.Fatalf("Write for directory FD failed with unexpected error: %v", err)
+ }
+}
+
+func TestDirFDIterDirents(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ // Fill root with nodes backed by various inode implementations.
+ "dir1": fs.newReadonlyDir(creds, 0755, nil),
+ "dir2": fs.newDir(creds, 0755, map[string]*kernfs.Dentry{
+ "dir3": fs.newDir(creds, 0755, nil),
+ }),
+ "file1": fs.newFile(creds, staticFileContent),
+ })
+ })
+
+ pop := sys.PathOpAtRoot("/")
+ fd, err := sys.vfs.OpenAt(sys.ctx, sys.creds, &pop, &vfs.OpenOptions{})
+ if err != nil {
+ sys.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
+ }
+ defer fd.DecRef()
+
+ collector := &direntCollector{}
+ if err := fd.IterDirents(sys.ctx, collector); err != nil {
+ sys.t.Fatalf("IterDirent failed: %v", err)
+ }
+
+ // Root directory should contain ".", ".." and 3 children:
+ if got := collector.count(); got != 5 {
+ sys.t.Fatalf("IterDirent returned %d dirents, want 5", got)
+ }
+ for _, dirName := range []string{".", "..", "dir1", "dir2"} {
+ if err := collector.contains(dirName, linux.DT_DIR); err != nil {
+ sys.t.Fatalf("IterDirent had unexpected results: %v", err)
+ }
+ }
+ if err := collector.contains("file1", linux.DT_REG); err != nil {
+ sys.t.Fatalf("IterDirent had unexpected results: %v", err)
+ }
+}
diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD
index 7e364c5fd..0cc751eb8 100644
--- a/pkg/sentry/fsimpl/memfs/BUILD
+++ b/pkg/sentry/fsimpl/memfs/BUILD
@@ -24,14 +24,19 @@ go_library(
"directory.go",
"filesystem.go",
"memfs.go",
+ "named_pipe.go",
"regular_file.go",
"symlink.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs",
deps = [
"//pkg/abi/linux",
+ "//pkg/amutex",
+ "//pkg/fspath",
+ "//pkg/sentry/arch",
"//pkg/sentry/context",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
"//pkg/syserror",
@@ -45,6 +50,7 @@ go_test(
deps = [
":memfs",
"//pkg/abi/linux",
+ "//pkg/refs",
"//pkg/sentry/context",
"//pkg/sentry/context/contexttest",
"//pkg/sentry/fs",
@@ -54,3 +60,19 @@ go_test(
"//pkg/syserror",
],
)
+
+go_test(
+ name = "memfs_test",
+ size = "small",
+ srcs = ["pipe_test.go"],
+ embed = [":memfs"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/context",
+ "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/memfs/benchmark_test.go b/pkg/sentry/fsimpl/memfs/benchmark_test.go
index a94b17db6..4a7a94a52 100644
--- a/pkg/sentry/fsimpl/memfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/memfs/benchmark_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -160,6 +161,8 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) {
b.Fatalf("stat(%q) failed: %v", filePath, err)
}
}
+ // Don't include deferred cleanup in benchmark time.
+ b.StopTimer()
})
}
}
@@ -173,10 +176,11 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
// Create VFS.
vfsObj := vfs.New()
vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.NewFilesystemOptions{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{})
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
+ defer mntns.DecRef(vfsObj)
var filePathBuilder strings.Builder
filePathBuilder.WriteByte('/')
@@ -186,7 +190,6 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
defer root.DecRef()
vd := root
vd.IncRef()
- defer vd.DecRef()
for i := depth; i > 0; i-- {
name := fmt.Sprintf("%d", i)
pop := vfs.PathOperation{
@@ -219,6 +222,8 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
Mode: 0644,
})
+ vd.DecRef()
+ vd = vfs.VirtualDentry{}
if err != nil {
b.Fatalf("failed to create file %q: %v", filename, err)
}
@@ -243,6 +248,8 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
b.Fatalf("got wrong permissions (%0o)", stat.Mode)
}
}
+ // Don't include deferred cleanup in benchmark time.
+ b.StopTimer()
})
}
}
@@ -343,6 +350,8 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
b.Fatalf("stat(%q) failed: %v", filePath, err)
}
}
+ // Don't include deferred cleanup in benchmark time.
+ b.StopTimer()
})
}
}
@@ -356,10 +365,11 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
// Create VFS.
vfsObj := vfs.New()
vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.NewFilesystemOptions{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{})
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
+ defer mntns.DecRef(vfsObj)
var filePathBuilder strings.Builder
filePathBuilder.WriteByte('/')
@@ -384,7 +394,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
}
defer mountPoint.DecRef()
// Create and mount the submount.
- if err := vfsObj.NewMount(ctx, creds, "", &pop, "memfs", &vfs.NewFilesystemOptions{}); err != nil {
+ if err := vfsObj.MountAt(ctx, creds, "", &pop, "memfs", &vfs.MountOptions{}); err != nil {
b.Fatalf("failed to mount tmpfs submount: %v", err)
}
filePathBuilder.WriteString(mountPointName)
@@ -395,7 +405,6 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to mount root: %v", err)
}
- defer vd.DecRef()
for i := depth; i > 0; i-- {
name := fmt.Sprintf("%d", i)
pop := vfs.PathOperation{
@@ -435,6 +444,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
Mode: 0644,
})
+ vd.DecRef()
if err != nil {
b.Fatalf("failed to create file %q: %v", filename, err)
}
@@ -459,6 +469,14 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
b.Fatalf("got wrong permissions (%0o)", stat.Mode)
}
}
+ // Don't include deferred cleanup in benchmark time.
+ b.StopTimer()
})
}
}
+
+func init() {
+ // Turn off reference leak checking for a fair comparison between vfs1 and
+ // vfs2.
+ refs.SetLeakMode(refs.NoLeakChecking)
+}
diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/memfs/filesystem.go
index f79e2d9c8..af4389459 100644
--- a/pkg/sentry/fsimpl/memfs/filesystem.go
+++ b/pkg/sentry/fsimpl/memfs/filesystem.go
@@ -19,6 +19,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -159,7 +160,7 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
return nil, err
}
}
- inode.incRef() // vfsd.IncRef(&fs.vfsfs)
+ inode.incRef()
return vfsd, nil
}
@@ -233,7 +234,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
if err != nil {
return err
}
- _, err = checkCreateLocked(rp, parentVFSD, parentInode)
+ pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
if err != nil {
return err
}
@@ -241,17 +242,49 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return err
}
defer rp.Mount().EndWrite()
- // TODO: actually implement mknod
- return syserror.EPERM
+
+ switch opts.Mode.FileType() {
+ case 0:
+ // "Zero file type is equivalent to type S_IFREG." - mknod(2)
+ fallthrough
+ case linux.ModeRegular:
+ // TODO(b/138862511): Implement.
+ return syserror.EINVAL
+
+ case linux.ModeNamedPipe:
+ child := fs.newDentry(fs.newNamedPipe(rp.Credentials(), opts.Mode))
+ parentVFSD.InsertChild(&child.vfsd, pc)
+ parentInode.impl.(*directory).childList.PushBack(child)
+ return nil
+
+ case linux.ModeSocket:
+ // TODO(b/138862511): Implement.
+ return syserror.EINVAL
+
+ case linux.ModeCharacterDevice:
+ fallthrough
+ case linux.ModeBlockDevice:
+ // TODO(b/72101894): We don't support creating block or character
+ // devices at the moment.
+ //
+ // When we start supporting block and character devices, we'll
+ // need to check for CAP_MKNOD here.
+ return syserror.EPERM
+
+ default:
+ // "EINVAL - mode requested creation of something other than a
+ // regular file, device special file, FIFO or socket." - mknod(2)
+ return syserror.EINVAL
+ }
}
// OpenAt implements vfs.FilesystemImpl.OpenAt.
func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
// Filter out flags that are not supported by memfs. O_DIRECTORY and
// O_NOFOLLOW have no effect here (they're handled by VFS by setting
- // appropriate bits in rp), but are returned by
- // FileDescriptionImpl.StatusFlags().
- opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW
+ // appropriate bits in rp), but are visible in FD status flags. O_NONBLOCK
+ // is supported only by pipes.
+ opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NONBLOCK
if opts.Flags&linux.O_CREAT == 0 {
fs.mu.RLock()
@@ -260,7 +293,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return nil, err
}
- return inode.open(rp, vfsd, opts.Flags, false)
+ return inode.open(ctx, rp, vfsd, opts.Flags, false)
}
mustCreate := opts.Flags&linux.O_EXCL != 0
@@ -275,7 +308,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if mustCreate {
return nil, syserror.EEXIST
}
- return inode.open(rp, vfsd, opts.Flags, false)
+ return inode.open(ctx, rp, vfsd, opts.Flags, false)
}
afterTrailingSymlink:
// Walk to the parent directory of the last path component.
@@ -320,7 +353,7 @@ afterTrailingSymlink:
child := fs.newDentry(childInode)
vfsd.InsertChild(&child.vfsd, pc)
inode.impl.(*directory).childList.PushBack(child)
- return childInode.open(rp, &child.vfsd, opts.Flags, true)
+ return childInode.open(ctx, rp, &child.vfsd, opts.Flags, true)
}
// Open existing file or follow symlink.
if mustCreate {
@@ -336,29 +369,31 @@ afterTrailingSymlink:
// symlink target.
goto afterTrailingSymlink
}
- return childInode.open(rp, childVFSD, opts.Flags, false)
+ return childInode.open(ctx, rp, childVFSD, opts.Flags, false)
}
-func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) {
+func (i *inode) open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) {
ats := vfs.AccessTypesForOpenFlags(flags)
if !afterCreate {
if err := i.checkPermissions(rp.Credentials(), ats, i.isDir()); err != nil {
return nil, err
}
}
+ mnt := rp.Mount()
switch impl := i.impl.(type) {
case *regularFile:
var fd regularFileFD
- fd.flags = flags
fd.readable = vfs.MayReadFileWithOpenFlags(flags)
fd.writable = vfs.MayWriteFileWithOpenFlags(flags)
if fd.writable {
- if err := rp.Mount().CheckBeginWrite(); err != nil {
+ if err := mnt.CheckBeginWrite(); err != nil {
return nil, err
}
- // Mount.EndWrite() is called by regularFileFD.Release().
+ // mnt.EndWrite() is called by regularFileFD.Release().
}
- fd.vfsfd.Init(&fd, rp.Mount(), vfsd)
+ mnt.IncRef()
+ vfsd.IncRef()
+ fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
if flags&linux.O_TRUNC != 0 {
impl.mu.Lock()
impl.data = impl.data[:0]
@@ -372,12 +407,15 @@ func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afte
return nil, syserror.EISDIR
}
var fd directoryFD
- fd.vfsfd.Init(&fd, rp.Mount(), vfsd)
- fd.flags = flags
+ mnt.IncRef()
+ vfsd.IncRef()
+ fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
case *symlink:
// Can't open symlinks without O_PATH (which is unimplemented).
return nil, syserror.ELOOP
+ case *namedPipe:
+ return newNamedPipeFD(ctx, impl, rp, vfsd, flags)
default:
panic(fmt.Sprintf("unknown inode type: %T", i.impl))
}
@@ -542,3 +580,58 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
inode.decLinksLocked()
return nil
}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ _, _, err := walkExistingLocked(rp)
+ if err != nil {
+ return nil, err
+ }
+ // TODO(b/127675828): support extended attributes
+ return nil, syserror.ENOTSUP
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ _, _, err := walkExistingLocked(rp)
+ if err != nil {
+ return "", err
+ }
+ // TODO(b/127675828): support extended attributes
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ _, _, err := walkExistingLocked(rp)
+ if err != nil {
+ return err
+ }
+ // TODO(b/127675828): support extended attributes
+ return syserror.ENOTSUP
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ _, _, err := walkExistingLocked(rp)
+ if err != nil {
+ return err
+ }
+ // TODO(b/127675828): support extended attributes
+ return syserror.ENOTSUP
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ return vfs.GenericPrependPath(vfsroot, vd, b)
+}
diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go
index b78471c0f..9d509f6e4 100644
--- a/pkg/sentry/fsimpl/memfs/memfs.go
+++ b/pkg/sentry/fsimpl/memfs/memfs.go
@@ -52,10 +52,10 @@ type filesystem struct {
nextInoMinusOne uint64 // accessed using atomic memory operations
}
-// NewFilesystem implements vfs.FilesystemType.NewFilesystem.
-func (fstype FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
var fs filesystem
- fs.vfsfs.Init(&fs)
+ fs.vfsfs.Init(vfsObj, &fs)
root := fs.newDentry(fs.newDirectory(creds, 01777))
return &fs.vfsfs, &root.vfsd, nil
}
@@ -99,17 +99,17 @@ func (fs *filesystem) newDentry(inode *inode) *dentry {
}
// IncRef implements vfs.DentryImpl.IncRef.
-func (d *dentry) IncRef(vfsfs *vfs.Filesystem) {
+func (d *dentry) IncRef() {
d.inode.incRef()
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
-func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool {
+func (d *dentry) TryIncRef() bool {
return d.inode.tryIncRef()
}
// DecRef implements vfs.DentryImpl.DecRef.
-func (d *dentry) DecRef(vfsfs *vfs.Filesystem) {
+func (d *dentry) DecRef() {
d.inode.decRef()
}
@@ -227,6 +227,8 @@ func (i *inode) statTo(stat *linux.Statx) {
stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
stat.Size = uint64(len(impl.target))
stat.Blocks = allocatedBlocksForSize(stat.Size)
+ case *namedPipe:
+ stat.Mode |= linux.S_IFIFO
default:
panic(fmt.Sprintf("unknown inode type: %T", i.impl))
}
@@ -259,28 +261,14 @@ func (i *inode) direntType() uint8 {
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
-
- flags uint32 // status flags; immutable
}
func (fd *fileDescription) filesystem() *filesystem {
- return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem)
+ return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
}
func (fd *fileDescription) inode() *inode {
- return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode
-}
-
-// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags.
-func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) {
- return fd.flags, nil
-}
-
-// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags.
-func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error {
- // None of the flags settable by fcntl(F_SETFL) are supported, so this is a
- // no-op.
- return nil
+ return fd.vfsfd.Dentry().Impl().(*dentry).inode
}
// Stat implements vfs.FileDescriptionImpl.Stat.
diff --git a/pkg/sentry/fsimpl/memfs/named_pipe.go b/pkg/sentry/fsimpl/memfs/named_pipe.go
new file mode 100644
index 000000000..d5060850e
--- /dev/null
+++ b/pkg/sentry/fsimpl/memfs/named_pipe.go
@@ -0,0 +1,62 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+type namedPipe struct {
+ inode inode
+
+ pipe *pipe.VFSPipe
+}
+
+// Preconditions:
+// * fs.mu must be locked.
+// * rp.Mount().CheckBeginWrite() has been called successfully.
+func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode {
+ file := &namedPipe{pipe: pipe.NewVFSPipe(pipe.DefaultPipeSize, usermem.PageSize)}
+ file.inode.init(file, fs, creds, mode)
+ file.inode.nlink = 1 // Only the parent has a link.
+ return &file.inode
+}
+
+// namedPipeFD implements vfs.FileDescriptionImpl. Methods are implemented
+// entirely via struct embedding.
+type namedPipeFD struct {
+ fileDescription
+
+ *pipe.VFSPipeFD
+}
+
+func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ var err error
+ var fd namedPipeFD
+ fd.VFSPipeFD, err = np.pipe.NewVFSPipeFD(ctx, rp, vfsd, &fd.vfsfd, flags)
+ if err != nil {
+ return nil, err
+ }
+ mnt := rp.Mount()
+ mnt.IncRef()
+ vfsd.IncRef()
+ fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
+ return &fd.vfsfd, nil
+}
diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/memfs/pipe_test.go
new file mode 100644
index 000000000..5bf527c80
--- /dev/null
+++ b/pkg/sentry/fsimpl/memfs/pipe_test.go
@@ -0,0 +1,233 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memfs
+
+import (
+ "bytes"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const fileName = "mypipe"
+
+func TestSeparateFDs(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the read side. This is done concurrently in a goroutine because
+ // opening one end of the pipe blocks until the other end is opened.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ rfdchan := make(chan *vfs.FileDescription)
+ go func() {
+ openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY}
+ rfd, _ := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ rfdchan <- rfd
+ }()
+
+ // Open the write side.
+ openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY}
+ wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+ }
+ defer wfd.DecRef()
+
+ rfd := <-rfdchan
+ if rfd == nil {
+ t.Fatalf("failed to open pipe for reading %q", fileName)
+ }
+ defer rfd.DecRef()
+
+ const msg = "vamos azul"
+ checkEmpty(ctx, t, rfd)
+ checkWrite(ctx, t, wfd, msg)
+ checkRead(ctx, t, rfd, msg)
+}
+
+func TestNonblockingRead(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the read side as nonblocking.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_NONBLOCK}
+ rfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for reading %q: %v", fileName, err)
+ }
+ defer rfd.DecRef()
+
+ // Open the write side.
+ openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY}
+ wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+ }
+ defer wfd.DecRef()
+
+ const msg = "geh blau"
+ checkEmpty(ctx, t, rfd)
+ checkWrite(ctx, t, wfd, msg)
+ checkRead(ctx, t, rfd, msg)
+}
+
+func TestNonblockingWriteError(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the write side as nonblocking, which should return ENXIO.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK}
+ _, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != syserror.ENXIO {
+ t.Fatalf("expected ENXIO, but got error: %v", err)
+ }
+}
+
+func TestSingleFD(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the pipe as readable and writable.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ openOpts := vfs.OpenOptions{Flags: linux.O_RDWR}
+ fd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+ }
+ defer fd.DecRef()
+
+ const msg = "forza blu"
+ checkEmpty(ctx, t, fd)
+ checkWrite(ctx, t, fd, msg)
+ checkRead(ctx, t, fd, msg)
+}
+
+// setup creates a VFS with a pipe in the root directory at path fileName. The
+// returned VirtualDentry must be DecRef()'d by the caller. It calls t.Fatal
+// upon failure.
+func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesystem, vfs.VirtualDentry) {
+ ctx := contexttest.Context(t)
+ creds := auth.CredentialsFromContext(ctx)
+
+ // Create VFS.
+ vfsObj := vfs.New()
+ vfsObj.MustRegisterFilesystemType("memfs", FilesystemType{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("failed to create tmpfs root mount: %v", err)
+ }
+
+ // Create the pipe.
+ root := mntns.Root()
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ mknodOpts := vfs.MknodOptions{Mode: linux.ModeNamedPipe | 0644}
+ if err := vfsObj.MknodAt(ctx, creds, &pop, &mknodOpts); err != nil {
+ t.Fatalf("failed to create file %q: %v", fileName, err)
+ }
+
+ // Sanity check: the pipe exists and has the correct mode.
+ stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }, &vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat(%q) failed: %v", fileName, err)
+ }
+ if stat.Mode&^linux.S_IFMT != 0644 {
+ t.Errorf("got wrong permissions (%0o)", stat.Mode)
+ }
+ if stat.Mode&linux.S_IFMT != linux.ModeNamedPipe {
+ t.Errorf("got wrong file type (%0o)", stat.Mode)
+ }
+
+ return ctx, creds, vfsObj, root
+}
+
+// checkEmpty calls t.Fatal if the pipe in fd is not empty.
+func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
+ readData := make([]byte, 1)
+ dst := usermem.BytesIOSequence(readData)
+ bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{})
+ if err != syserror.ErrWouldBlock {
+ t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err)
+ }
+ if bytesRead != 0 {
+ t.Fatalf("expected to read 0 bytes, but got %d", bytesRead)
+ }
+}
+
+// checkWrite calls t.Fatal if it fails to write all of msg to fd.
+func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
+ writeData := []byte(msg)
+ src := usermem.BytesIOSequence(writeData)
+ bytesWritten, err := fd.Write(ctx, src, vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("error writing to pipe %q: %v", fileName, err)
+ }
+ if bytesWritten != int64(len(writeData)) {
+ t.Fatalf("expected to write %d bytes, but wrote %d", len(writeData), bytesWritten)
+ }
+}
+
+// checkRead calls t.Fatal if it fails to read msg from fd.
+func checkRead(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
+ readData := make([]byte, len(msg))
+ dst := usermem.BytesIOSequence(readData)
+ bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{})
+ if err != nil {
+ t.Fatalf("error reading from pipe %q: %v", fileName, err)
+ }
+ if bytesRead != int64(len(msg)) {
+ t.Fatalf("expected to read %d bytes, but got %d", len(msg), bytesRead)
+ }
+ if !bytes.Equal(readData, []byte(msg)) {
+ t.Fatalf("expected to read %q from pipe, but got %q", msg, string(readData))
+ }
+}
diff --git a/pkg/sentry/fsimpl/proc/filesystems.go b/pkg/sentry/fsimpl/proc/filesystems.go
index c36c4aff5..0e016bca5 100644
--- a/pkg/sentry/fsimpl/proc/filesystems.go
+++ b/pkg/sentry/fsimpl/proc/filesystems.go
@@ -19,7 +19,7 @@ package proc
// +stateify savable
type filesystemsData struct{}
-// TODO(b/138862512): Implement vfs.DynamicBytesSource.Generate for
+// TODO(gvisor.dev/issue/1195): Implement vfs.DynamicBytesSource.Generate for
// filesystemsData. We would need to retrieve filesystem names from
// vfs.VirtualFilesystem. Also needs vfs replacement for
// fs.Filesystem.AllowUserList() and fs.FilesystemRequiresDev.
diff --git a/pkg/sentry/fsimpl/proc/mounts.go b/pkg/sentry/fsimpl/proc/mounts.go
index e81b1e910..8683cf677 100644
--- a/pkg/sentry/fsimpl/proc/mounts.go
+++ b/pkg/sentry/fsimpl/proc/mounts.go
@@ -16,7 +16,7 @@ package proc
import "gvisor.dev/gvisor/pkg/sentry/kernel"
-// TODO(b/138862512): Implement mountInfoFile and mountsFile.
+// TODO(gvisor.dev/issue/1195): Implement mountInfoFile and mountsFile.
// mountInfoFile implements vfs.DynamicBytesSource for /proc/[pid]/mountinfo.
//
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index c46e05c3a..0d87be52b 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -229,6 +229,10 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
fmt.Fprintf(buf, "Seccomp:\t%d\n", s.t.SeccompMode())
+ // We unconditionally report a single NUMA node. See
+ // pkg/sentry/syscalls/linux/sys_mempolicy.go.
+ fmt.Fprintf(buf, "Mems_allowed:\t1\n")
+ fmt.Fprintf(buf, "Mems_allowed_list:\t0\n")
return nil
}
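With these two lines, the tail of /proc/[pid]/status read inside the sandbox gains two constant entries after the Seccomp line, for example (the Seccomp value still depends on the task):

    Seccomp:	0
    Mems_allowed:	1
    Mems_allowed_list:	0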
diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD
index d5284f0d9..8d60ad4ad 100644
--- a/pkg/sentry/inet/BUILD
+++ b/pkg/sentry/inet/BUILD
@@ -13,5 +13,8 @@ go_library(
"test_stack.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/inet",
- deps = ["//pkg/sentry/context"],
+ deps = [
+ "//pkg/sentry/context",
+ "//pkg/tcpip/stack",
+ ],
)
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index 80f227dbe..a7dfb78a7 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -15,6 +15,8 @@
// Package inet defines semantics for IP stacks.
package inet
+import "gvisor.dev/gvisor/pkg/tcpip/stack"
+
// Stack represents a TCP/IP stack.
type Stack interface {
// Interfaces returns all network interfaces as a mapping from interface
@@ -58,6 +60,16 @@ type Stack interface {
// Resume restarts the network stack after restore.
Resume()
+
+ // RegisteredEndpoints returns all endpoints which are currently registered.
+ RegisteredEndpoints() []stack.TransportEndpoint
+
+ // CleanupEndpoints returns endpoints currently in the cleanup state.
+ CleanupEndpoints() []stack.TransportEndpoint
+
+ // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
+ // for restoring a stack after a save.
+ RestoreCleanupEndpoints([]stack.TransportEndpoint)
}
// Interface contains information about a network interface.
@@ -153,3 +165,23 @@ type Route struct {
// GatewayAddr is the route gateway address (RTA_GATEWAY).
GatewayAddr []byte
}
+
+// The SNMP metrics below are from Linux's usr/include/linux/snmp.h.
+
+// StatSNMPIP describes Ip line of /proc/net/snmp.
+type StatSNMPIP [19]uint64
+
+// StatSNMPICMP describes Icmp line of /proc/net/snmp.
+type StatSNMPICMP [27]uint64
+
+// StatSNMPICMPMSG describes IcmpMsg line of /proc/net/snmp.
+type StatSNMPICMPMSG [512]uint64
+
+// StatSNMPTCP describes Tcp line of /proc/net/snmp.
+type StatSNMPTCP [15]uint64
+
+// StatSNMPUDP describes Udp line of /proc/net/snmp.
+type StatSNMPUDP [8]uint64
+
+// StatSNMPUDPLite describes UdpLite line of /proc/net/snmp.
+type StatSNMPUDPLite [8]uint64
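Each array length matches the number of counters on the corresponding /proc/net/snmp line. As a sketch of how a Tcp line could be rendered from such an array — the field names follow Linux's snmp.h ordering, and the rendering helper is illustrative rather than part of this package:

    package main

    import (
        "fmt"
        "strings"
    )

    // StatSNMPTCP mirrors the 15 Tcp counters of /proc/net/snmp.
    type StatSNMPTCP [15]uint64

    var tcpNames = []string{
        "RtoAlgorithm", "RtoMin", "RtoMax", "MaxConn", "ActiveOpens",
        "PassiveOpens", "AttemptFails", "EstabResets", "CurrEstab",
        "InSegs", "OutSegs", "RetransSegs", "InErrs", "OutRsts",
        "InCsumErrors",
    }

    func main() {
        var tcp StatSNMPTCP
        tcp[8] = 3 // CurrEstab, an illustrative value

        vals := make([]string, len(tcp))
        for i, v := range tcp {
            vals[i] = fmt.Sprintf("%d", v)
        }
        // /proc/net/snmp emits a header line of names, then a line of values.
        fmt.Printf("Tcp: %s\n", strings.Join(tcpNames, " "))
        fmt.Printf("Tcp: %s\n", strings.Join(vals, " "))
    }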
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index b9eed7c3a..dcfcbd97e 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -14,6 +14,8 @@
package inet
+import "gvisor.dev/gvisor/pkg/tcpip/stack"
+
// TestStack is a dummy implementation of Stack for tests.
type TestStack struct {
InterfacesMap map[int32]Interface
@@ -94,5 +96,17 @@ func (s *TestStack) RouteTable() []Route {
}
// Resume implements Stack.Resume.
-func (s *TestStack) Resume() {
+func (s *TestStack) Resume() {}
+
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint {
+ return nil
}
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint {
+ return nil
+}
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index e041c51b3..2706927ff 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -35,7 +35,7 @@ go_template_instance(
out = "seqatomic_taskgoroutineschedinfo_unsafe.go",
package = "kernel",
suffix = "TaskGoroutineSchedInfo",
- template = "//third_party/gvsync:generic_seqatomic",
+ template = "//pkg/syncutil:generic_seqatomic",
types = {
"Value": "TaskGoroutineSchedInfo",
},
@@ -209,12 +209,12 @@ go_library(
"//pkg/sentry/usermem",
"//pkg/state",
"//pkg/state/statefile",
+ "//pkg/syncutil",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/stack",
"//pkg/waiter",
- "//third_party/gvsync",
],
)
diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD
index 51de4568a..04c244447 100644
--- a/pkg/sentry/kernel/auth/BUILD
+++ b/pkg/sentry/kernel/auth/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "atomicptr_credentials_unsafe.go",
package = "auth",
suffix = "Credentials",
- template = "//third_party/gvsync:generic_atomicptr",
+ template = "//pkg/syncutil:generic_atomicptr",
types = {
"Value": "Credentials",
},
diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go
index e3f5b0d83..3c9dceaba 100644
--- a/pkg/sentry/kernel/context.go
+++ b/pkg/sentry/kernel/context.go
@@ -15,6 +15,8 @@
package kernel
import (
+ "time"
+
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
)
@@ -97,6 +99,21 @@ func TaskFromContext(ctx context.Context) *Task {
return nil
}
+// Deadline implements context.Context.Deadline.
+func (*Task) Deadline() (time.Time, bool) {
+ return time.Time{}, false
+}
+
+// Done implements context.Context.Done.
+func (*Task) Done() <-chan struct{} {
+ return nil
+}
+
+// Err implements context.Context.Err.
+func (*Task) Err() error {
+ return nil
+}
+
// AsyncContext returns a context.Context that may be used by goroutines that
// do work on behalf of t and therefore share its contextual values, but are
// not t's task goroutine (e.g. asynchronous I/O).
@@ -129,6 +146,21 @@ func (ctx taskAsyncContext) IsLogging(level log.Level) bool {
return ctx.t.IsLogging(level)
}
+// Deadline implements context.Context.Deadline.
+func (ctx taskAsyncContext) Deadline() (time.Time, bool) {
+ return ctx.t.Deadline()
+}
+
+// Done implements context.Context.Done.
+func (ctx taskAsyncContext) Done() <-chan struct{} {
+ return ctx.t.Done()
+}
+
+// Err implements context.Context.Err.
+func (ctx taskAsyncContext) Err() error {
+ return ctx.t.Err()
+}
+
// Value implements context.Context.Value.
func (ctx taskAsyncContext) Value(key interface{}) interface{} {
return ctx.t.Value(key)
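Together with the existing Value method, Deadline, Done, and Err give *Task and taskAsyncContext the full method set of the standard library's context.Context. The shape used above — zero deadline, nil Done channel, nil error — is the same "never cancelled" contract as context.Background(); a standalone sketch (the neverCancelled type is purely illustrative):

    package main

    import (
        "context"
        "fmt"
        "time"
    )

    // neverCancelled mimics the methods added above: no deadline, a nil Done
    // channel, and a nil error, so it is never considered cancelled.
    type neverCancelled struct{}

    func (neverCancelled) Deadline() (time.Time, bool)       { return time.Time{}, false }
    func (neverCancelled) Done() <-chan struct{}             { return nil }
    func (neverCancelled) Err() error                        { return nil }
    func (neverCancelled) Value(key interface{}) interface{} { return nil }

    func main() {
        var ctx context.Context = neverCancelled{} // satisfies the interface
        select {
        case <-ctx.Done():
            fmt.Println("cancelled:", ctx.Err())
        default:
            // Receiving from a nil channel blocks forever, so with a default
            // case this branch is always taken.
            fmt.Println("never cancelled")
        }
    }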
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index cc3f43a45..11f613a11 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -81,6 +81,9 @@ type FDTable struct {
// mu protects below.
mu sync.Mutex `state:"nosave"`
+ // next is the starting position from which to search for an available fd.
+ next int32
+
// used contains the number of non-nil entries. It must be accessed
// atomically. It may be read atomically without holding mu (but not
// written).
@@ -226,6 +229,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
f.mu.Lock()
defer f.mu.Unlock()
+ // Start the search for an available fd at f.next.
+ if fd < f.next {
+ fd = f.next
+ }
+
// Install all entries.
for i := fd; i < end && len(fds) < len(files); i++ {
if d, _, _ := f.get(i); d == nil {
@@ -242,6 +250,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
return nil, syscall.EMFILE
}
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fds[len(fds)-1] + 1
+ }
+
return fds, nil
}
@@ -361,6 +374,11 @@ func (f *FDTable) Remove(fd int32) *fs.File {
f.mu.Lock()
defer f.mu.Unlock()
+ // Move the search start position back if a lower fd is now free.
+ if fd < f.next {
+ f.next = fd
+ }
+
orig, _, _ := f.get(fd)
if orig != nil {
orig.IncRef() // Reference for caller.
@@ -377,6 +395,10 @@ func (f *FDTable) RemoveIf(cond func(*fs.File, FDFlags) bool) {
f.forEach(func(fd int32, file *fs.File, flags FDFlags) {
if cond(file, flags) {
f.set(fd, nil, FDFlags{}) // Clear from table.
+ // Pull the search start position back if this fd is lower.
+ if fd < f.next {
+ f.next = fd
+ }
}
})
}
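The new next field is a lowest-free-slot hint: NewFDs scans upward from next instead of from zero, and Remove/RemoveIf pull the hint back down whenever a lower descriptor is freed, preserving the invariant that no slot below next is free. A standalone sketch of the same hint strategy on a plain bool slice (hypothetical names, not the gVisor implementation):

package main

import "fmt"

// hintTable hands out the lowest free index, keeping a hint so that
// repeated allocations do not rescan from zero every time.
type hintTable struct {
	used []bool
	next int // no index below next is free
}

func (t *hintTable) alloc() int {
	for i := t.next; i < len(t.used); i++ {
		if !t.used[i] {
			t.used[i] = true
			t.next = i + 1 // the next search can start above this slot
			return i
		}
	}
	return -1 // table full
}

func (t *hintTable) free(i int) {
	t.used[i] = false
	if i < t.next {
		t.next = i // a lower slot opened up; pull the hint back down
	}
}

func main() {
	t := &hintTable{used: make([]bool, 8)}
	a, b, c := t.alloc(), t.alloc(), t.alloc()
	t.free(b)
	fmt.Println(a, b, c, t.alloc()) // 0 1 2 1: the freed slot is reused first
}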
diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go
index 2413788e7..2bcb6216a 100644
--- a/pkg/sentry/kernel/fd_table_test.go
+++ b/pkg/sentry/kernel/fd_table_test.go
@@ -70,6 +70,42 @@ func TestFDTableMany(t *testing.T) {
if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil {
t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err)
}
+
+ i := int32(2)
+ fdTable.Remove(i)
+ if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i {
+ t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err)
+ }
+ })
+}
+
+func TestFDTableOverLimit(t *testing.T) {
+ runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) {
+ if _, err := fdTable.NewFDs(ctx, maxFD, []*fs.File{file}, FDFlags{}); err == nil {
+ t.Fatalf("fdTable.NewFDs(maxFD, f): got nil, wanted error")
+ }
+
+ if _, err := fdTable.NewFDs(ctx, maxFD-2, []*fs.File{file, file, file}, FDFlags{}); err == nil {
+ t.Fatalf("fdTable.NewFDs(maxFD-2, {f,f,f}): got nil, wanted error")
+ }
+
+ if fds, err := fdTable.NewFDs(ctx, maxFD-3, []*fs.File{file, file, file}, FDFlags{}); err != nil {
+ t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err)
+ } else {
+ for _, fd := range fds {
+ fdTable.Remove(fd)
+ }
+ }
+
+ if fds, err := fdTable.NewFDs(ctx, maxFD-1, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != maxFD-1 {
+ t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err)
+ }
+
+ if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil {
+ t.Fatalf("Adding an FD to a resized map: got %v, want nil", err)
+ } else if len(fds) != 1 || fds[0] != 0 {
+ t.Fatalf("Added an FD to a resized map: got %v, want {1}", fds)
+ }
})
}
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index 34286c7a8..75ec31761 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -9,7 +9,7 @@ go_template_instance(
out = "atomicptr_bucket_unsafe.go",
package = "futex",
suffix = "Bucket",
- template = "//third_party/gvsync:generic_atomicptr",
+ template = "//pkg/syncutil:generic_atomicptr",
types = {
"Value": "bucket",
},
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 3cda03891..bd3fb4c03 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -391,7 +391,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
//
// N.B. This will also be saved along with the full kernel save below.
cpuidStart := time.Now()
- if err := state.Save(w, k.FeatureSet(), nil); err != nil {
+ if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil {
return err
}
log.Infof("CPUID save took [%s].", time.Since(cpuidStart))
@@ -399,7 +399,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
// Save the kernel state.
kernelStart := time.Now()
var stats state.Stats
- if err := state.Save(w, k, &stats); err != nil {
+ if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil {
return err
}
log.Infof("Kernel save stats: %s", &stats)
@@ -407,7 +407,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
// Save the memory file's state.
memoryStart := time.Now()
- if err := k.mf.SaveTo(w); err != nil {
+ if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil {
return err
}
log.Infof("Memory save took [%s].", time.Since(memoryStart))
@@ -542,7 +542,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// don't need to explicitly install it in the Kernel.
cpuidStart := time.Now()
var features cpuid.FeatureSet
- if err := state.Load(r, &features, nil); err != nil {
+ if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil {
return err
}
log.Infof("CPUID load took [%s].", time.Since(cpuidStart))
@@ -558,7 +558,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// Load the kernel state.
kernelStart := time.Now()
var stats state.Stats
- if err := state.Load(r, k, &stats); err != nil {
+ if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil {
return err
}
log.Infof("Kernel load stats: %s", &stats)
@@ -566,7 +566,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// Load the memory file's state.
memoryStart := time.Now()
- if err := k.mf.LoadFrom(r); err != nil {
+ if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
return err
}
log.Infof("Memory load took [%s].", time.Since(memoryStart))
@@ -804,8 +804,21 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
// Create a fresh task context.
remainingTraversals = uint(args.MaxSymlinkTraversals)
+ loadArgs := loader.LoadArgs{
+ Mounts: mounts,
+ Root: root,
+ WorkingDirectory: wd,
+ RemainingTraversals: &remainingTraversals,
+ ResolveFinal: true,
+ Filename: args.Filename,
+ File: args.File,
+ CloseOnExec: false,
+ Argv: args.Argv,
+ Envv: args.Envv,
+ Features: k.featureSet,
+ }
- tc, se := k.LoadTaskImage(ctx, mounts, root, wd, &remainingTraversals, args.Filename, args.File, args.Argv, args.Envv, k.featureSet)
+ tc, se := k.LoadTaskImage(ctx, loadArgs)
if se != nil {
return nil, 0, errors.New(se.String())
}
@@ -828,9 +841,11 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
AbstractSocketNamespace: args.AbstractSocketNamespace,
ContainerID: args.ContainerID,
}
- if _, err := k.tasks.NewTask(config); err != nil {
+ t, err := k.tasks.NewTask(config)
+ if err != nil {
return nil, 0, err
}
+ t.traceExecEvent(tc) // Simulate exec for tracing.
// Success.
tgid := k.tasks.Root.IDOfThreadGroup(tg)
@@ -1105,6 +1120,22 @@ func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error {
return lastErr
}
+// RebuildTraceContexts rebuilds the trace context for all tasks.
+//
+// Unfortunately, if these are built while tracing is not enabled, then we will
+// not have meaningful trace data. Rebuilding here ensures that we can do so
+// after tracing has been enabled.
+func (k *Kernel) RebuildTraceContexts() {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ k.tasks.mu.RLock()
+ defer k.tasks.mu.RUnlock()
+
+ for t, tid := range k.tasks.Root.tids {
+ t.rebuildTraceContext(tid)
+ }
+}
+
// FeatureSet returns the FeatureSet.
func (k *Kernel) FeatureSet() *cpuid.FeatureSet {
return k.featureSet
@@ -1309,6 +1340,7 @@ func (k *Kernel) ListSockets() []*SocketEntry {
return socks
}
+// supervisorContext is a privileged context.
type supervisorContext struct {
context.NoopSleeper
log.Logger
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index cde647139..9d34f6d4d 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -24,8 +24,10 @@ go_library(
"device.go",
"node.go",
"pipe.go",
+ "pipe_util.go",
"reader.go",
"reader_writer.go",
+ "vfs.go",
"writer.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/pipe",
@@ -40,6 +42,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/safemem",
"//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go
index a2dc72204..4a19ab7ce 100644
--- a/pkg/sentry/kernel/pipe/node.go
+++ b/pkg/sentry/kernel/pipe/node.go
@@ -18,7 +18,6 @@ import (
"sync"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -91,10 +90,10 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
switch {
case flags.Read && !flags.Write: // O_RDONLY.
r := i.p.Open(ctx, d, flags)
- i.newHandleLocked(&i.rWakeup)
+ newHandleLocked(&i.rWakeup)
if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() {
- if !i.waitFor(&i.wWakeup, ctx) {
+ if !waitFor(&i.mu, &i.wWakeup, ctx) {
r.DecRef()
return nil, syserror.ErrInterrupted
}
@@ -107,7 +106,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
case flags.Write && !flags.Read: // O_WRONLY.
w := i.p.Open(ctx, d, flags)
- i.newHandleLocked(&i.wWakeup)
+ newHandleLocked(&i.wWakeup)
if i.p.isNamed && !i.p.HasReaders() {
// On a nonblocking, write-only open, the open fails with ENXIO if the
@@ -117,7 +116,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
return nil, syserror.ENXIO
}
- if !i.waitFor(&i.rWakeup, ctx) {
+ if !waitFor(&i.mu, &i.rWakeup, ctx) {
w.DecRef()
return nil, syserror.ErrInterrupted
}
@@ -127,8 +126,8 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
case flags.Read && flags.Write: // O_RDWR.
// Pipes opened for read-write always succeed without blocking.
rw := i.p.Open(ctx, d, flags)
- i.newHandleLocked(&i.rWakeup)
- i.newHandleLocked(&i.wWakeup)
+ newHandleLocked(&i.rWakeup)
+ newHandleLocked(&i.wWakeup)
return rw, nil
default:
@@ -136,65 +135,6 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
}
}
-// waitFor blocks until the underlying pipe has at least one reader/writer is
-// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this
-// function will block for either readers or writers, depending on where
-// 'wakeupChan' points.
-//
-// f.mu must be held by the caller. waitFor returns with f.mu held, but it will
-// drop f.mu before blocking for any reader/writers.
-func (i *inodeOperations) waitFor(wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool {
- // Ideally this function would simply use a condition variable. However, the
- // wait needs to be interruptible via 'sleeper', so we must sychronize via a
- // channel. The synchronization below relies on the fact that closing a
- // channel unblocks all receives on the channel.
-
- // Does an appropriate wakeup channel already exist? If not, create a new
- // one. This is all done under f.mu to avoid races.
- if *wakeupChan == nil {
- *wakeupChan = make(chan struct{})
- }
-
- // Grab a local reference to the wakeup channel since it may disappear as
- // soon as we drop f.mu.
- wakeup := *wakeupChan
-
- // Drop the lock and prepare to sleep.
- i.mu.Unlock()
- cancel := sleeper.SleepStart()
-
- // Wait for either a new reader/write to be signalled via 'wakeup', or
- // for the sleep to be cancelled.
- select {
- case <-wakeup:
- sleeper.SleepFinish(true)
- case <-cancel:
- sleeper.SleepFinish(false)
- }
-
- // Take the lock and check if we were woken. If we were woken and
- // interrupted, the former takes priority.
- i.mu.Lock()
- select {
- case <-wakeup:
- return true
- default:
- return false
- }
-}
-
-// newHandleLocked signals a new pipe reader or writer depending on where
-// 'wakeupChan' points. This unblocks any corresponding reader or writer
-// waiting for the other end of the channel to be opened, see Fifo.waitFor.
-//
-// i.mu must be held.
-func (*inodeOperations) newHandleLocked(wakeupChan *chan struct{}) {
- if *wakeupChan != nil {
- close(*wakeupChan)
- *wakeupChan = nil
- }
-}
-
func (*inodeOperations) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error {
return syserror.EPIPE
}
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 8e4e8e82e..1a1b38f83 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -111,11 +111,27 @@ func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe {
if atomicIOBytes > sizeBytes {
atomicIOBytes = sizeBytes
}
- return &Pipe{
- isNamed: isNamed,
- max: sizeBytes,
- atomicIOBytes: atomicIOBytes,
+ var p Pipe
+ initPipe(&p, isNamed, sizeBytes, atomicIOBytes)
+ return &p
+}
+
+func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) {
+ if sizeBytes < MinimumPipeSize {
+ sizeBytes = MinimumPipeSize
+ }
+ if sizeBytes > MaximumPipeSize {
+ sizeBytes = MaximumPipeSize
+ }
+ if atomicIOBytes <= 0 {
+ atomicIOBytes = 1
+ }
+ if atomicIOBytes > sizeBytes {
+ atomicIOBytes = sizeBytes
}
+ pipe.isNamed = isNamed
+ pipe.max = sizeBytes
+ pipe.atomicIOBytes = atomicIOBytes
}
// NewConnectedPipe initializes a pipe and returns a pair of objects
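initPipe normalizes its parameters rather than rejecting them: the buffer size is clamped into the [MinimumPipeSize, MaximumPipeSize] range and the atomic-write threshold into [1, sizeBytes], so callers may pass zeros and still get a usable pipe. A minimal sketch of the same clamping, with placeholder constants standing in for the package's real limits:

package main

import "fmt"

// Placeholder limits; the real bounds are the pipe package's
// MinimumPipeSize and MaximumPipeSize.
const (
	minPipeSize = 4096
	maxPipeSize = 1 << 20
)

// clampPipeParams mirrors initPipe's normalization: size is forced into
// [min, max] and the atomic I/O threshold into [1, size].
func clampPipeParams(sizeBytes, atomicIOBytes int64) (int64, int64) {
	if sizeBytes < minPipeSize {
		sizeBytes = minPipeSize
	}
	if sizeBytes > maxPipeSize {
		sizeBytes = maxPipeSize
	}
	if atomicIOBytes <= 0 {
		atomicIOBytes = 1
	}
	if atomicIOBytes > sizeBytes {
		atomicIOBytes = sizeBytes
	}
	return sizeBytes, atomicIOBytes
}

func main() {
	fmt.Println(clampPipeParams(0, 0))         // 4096 1
	fmt.Println(clampPipeParams(1<<30, 1<<30)) // 1048576 1048576
}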
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
new file mode 100644
index 000000000..ef9641e6a
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -0,0 +1,213 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "io"
+ "math"
+ "sync"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// This file contains Pipe file functionality that is tied to neither VFS nor
+// the old fs architecture.
+
+// Release cleans up the pipe's state.
+func (p *Pipe) Release() {
+ p.rClose()
+ p.wClose()
+
+ // Wake up readers and writers.
+ p.Notify(waiter.EventIn | waiter.EventOut)
+}
+
+// Read reads from the Pipe into dst.
+func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ n, err := p.read(ctx, readOps{
+ left: func() int64 {
+ return dst.NumBytes()
+ },
+ limit: func(l int64) {
+ dst = dst.TakeFirst64(l)
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, buf)
+ dst = dst.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ p.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// WriteTo writes to w from the Pipe.
+func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) (int64, error) {
+ ops := readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := buf.ReadToWriter(w, count, dup)
+ count -= n
+ return n, err
+ },
+ }
+ if dup {
+ // There is no notification for dup operations.
+ return p.dup(ctx, ops)
+ }
+ n, err := p.read(ctx, ops)
+ if n > 0 {
+ p.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// Write writes to the Pipe from src.
+func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ n, err := p.write(ctx, writeOps{
+ left: func() int64 {
+ return src.NumBytes()
+ },
+ limit: func(l int64) {
+ src = src.TakeFirst64(l)
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := src.CopyInTo(ctx, buf)
+ src = src.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ p.Notify(waiter.EventIn)
+ }
+ return n, err
+}
+
+// ReadFrom reads from r to the Pipe.
+func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, error) {
+ n, err := p.write(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := buf.WriteFromReader(r, count)
+ count -= n
+ return n, err
+ },
+ })
+ if n > 0 {
+ p.Notify(waiter.EventIn)
+ }
+ return n, err
+}
+
+// Readiness returns the ready events in the underlying pipe.
+func (p *Pipe) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return p.rwReadiness() & mask
+}
+
+// Ioctl implements ioctls on the Pipe.
+func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // Switch on ioctl request.
+ switch int(args[1].Int()) {
+ case linux.FIONREAD:
+ v := p.queued()
+ if v > math.MaxInt32 {
+ v = math.MaxInt32 // Silently truncate.
+ }
+ // Copy result to user-space.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ default:
+ return 0, syscall.ENOTTY
+ }
+}
+
+// waitFor blocks until at least one reader/writer of the underlying pipe is
+// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this
+// function will block for either readers or writers, depending on where
+// 'wakeupChan' points.
+//
+// mu must be held by the caller. waitFor returns with mu held, but it will
+// drop mu before blocking for any reader/writers.
+func waitFor(mu *sync.Mutex, wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool {
+ // Ideally this function would simply use a condition variable. However, the
+ // wait needs to be interruptible via 'sleeper', so we must synchronize via a
+ // channel. The synchronization below relies on the fact that closing a
+ // channel unblocks all receives on the channel.
+
+ // Does an appropriate wakeup channel already exist? If not, create a new
+ // one. This is all done under mu to avoid races.
+ if *wakeupChan == nil {
+ *wakeupChan = make(chan struct{})
+ }
+
+ // Grab a local reference to the wakeup channel since it may disappear as
+ // soon as we drop mu.
+ wakeup := *wakeupChan
+
+ // Drop the lock and prepare to sleep.
+ mu.Unlock()
+ cancel := sleeper.SleepStart()
+
+ // Wait for either a new reader/writer to be signalled via 'wakeup', or
+ // for the sleep to be cancelled.
+ select {
+ case <-wakeup:
+ sleeper.SleepFinish(true)
+ case <-cancel:
+ sleeper.SleepFinish(false)
+ }
+
+ // Take the lock and check if we were woken. If we were woken and
+ // interrupted, the former takes priority.
+ mu.Lock()
+ select {
+ case <-wakeup:
+ return true
+ default:
+ return false
+ }
+}
+
+// newHandleLocked signals a new pipe reader or writer depending on where
+// 'wakeupChan' points. This unblocks any corresponding reader or writer
+// waiting for the other end of the pipe to be opened; see waitFor.
+//
+// Precondition: the mutex protecting wakeupChan must be held.
+func newHandleLocked(wakeupChan *chan struct{}) {
+ if *wakeupChan != nil {
+ close(*wakeupChan)
+ *wakeupChan = nil
+ }
+}
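The comment in waitFor names the key trick: closing a channel unblocks every receiver, which gives a condition-variable-style broadcast that can also be raced against a cancellation channel inside a select. A simplified, self-contained sketch of that pattern (it omits the sleeper machinery and the post-wakeup recheck that the real waitFor performs):

package main

import (
	"fmt"
	"sync"
	"time"
)

// waitOrCancel drops mu, blocks until wakeup is closed or cancel fires,
// and reacquires mu before returning. Closing wakeup wakes all waiters
// at once, which is what makes the channel act as a broadcast.
func waitOrCancel(mu *sync.Mutex, wakeup, cancel <-chan struct{}) bool {
	mu.Unlock()
	defer mu.Lock()
	select {
	case <-wakeup:
		return true // the other side showed up
	case <-cancel:
		return false // interrupted before a wakeup
	}
}

func main() {
	var mu sync.Mutex
	wakeup := make(chan struct{})
	cancel := make(chan struct{})

	go func() {
		time.Sleep(10 * time.Millisecond)
		close(wakeup) // announce the other end; all waiters wake
	}()

	mu.Lock()
	ok := waitOrCancel(&mu, wakeup, cancel)
	mu.Unlock()
	fmt.Println("woken:", ok)
}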
diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go
index 7c307f013..b4d29fc77 100644
--- a/pkg/sentry/kernel/pipe/reader_writer.go
+++ b/pkg/sentry/kernel/pipe/reader_writer.go
@@ -16,16 +16,12 @@ package pipe
import (
"io"
- "math"
- "syscall"
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/waiter"
)
// ReaderWriter satisfies the FileOperations interface and services both
@@ -45,124 +41,27 @@ type ReaderWriter struct {
*Pipe
}
-// Release implements fs.FileOperations.Release.
-func (rw *ReaderWriter) Release() {
- rw.Pipe.rClose()
- rw.Pipe.wClose()
-
- // Wake up readers and writers.
- rw.Pipe.Notify(waiter.EventIn | waiter.EventOut)
-}
-
// Read implements fs.FileOperations.Read.
func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.read(ctx, readOps{
- left: func() int64 {
- return dst.NumBytes()
- },
- limit: func(l int64) {
- dst = dst.TakeFirst64(l)
- },
- read: func(buf *buffer) (int64, error) {
- n, err := dst.CopyOutFrom(ctx, buf)
- dst = dst.DropFirst64(n)
- return n, err
- },
- })
- if n > 0 {
- rw.Pipe.Notify(waiter.EventOut)
- }
- return n, err
+ return rw.Pipe.Read(ctx, dst)
}
// WriteTo implements fs.FileOperations.WriteTo.
func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) {
- ops := readOps{
- left: func() int64 {
- return count
- },
- limit: func(l int64) {
- count = l
- },
- read: func(buf *buffer) (int64, error) {
- n, err := buf.ReadToWriter(w, count, dup)
- count -= n
- return n, err
- },
- }
- if dup {
- // There is no notification for dup operations.
- return rw.Pipe.dup(ctx, ops)
- }
- n, err := rw.Pipe.read(ctx, ops)
- if n > 0 {
- rw.Pipe.Notify(waiter.EventOut)
- }
- return n, err
+ return rw.Pipe.WriteTo(ctx, w, count, dup)
}
// Write implements fs.FileOperations.Write.
func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.write(ctx, writeOps{
- left: func() int64 {
- return src.NumBytes()
- },
- limit: func(l int64) {
- src = src.TakeFirst64(l)
- },
- write: func(buf *buffer) (int64, error) {
- n, err := src.CopyInTo(ctx, buf)
- src = src.DropFirst64(n)
- return n, err
- },
- })
- if n > 0 {
- rw.Pipe.Notify(waiter.EventIn)
- }
- return n, err
+ return rw.Pipe.Write(ctx, src)
}
// ReadFrom implements fs.FileOperations.WriteTo.
func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
- n, err := rw.Pipe.write(ctx, writeOps{
- left: func() int64 {
- return count
- },
- limit: func(l int64) {
- count = l
- },
- write: func(buf *buffer) (int64, error) {
- n, err := buf.WriteFromReader(r, count)
- count -= n
- return n, err
- },
- })
- if n > 0 {
- rw.Pipe.Notify(waiter.EventIn)
- }
- return n, err
-}
-
-// Readiness returns the ready events in the underlying pipe.
-func (rw *ReaderWriter) Readiness(mask waiter.EventMask) waiter.EventMask {
- return rw.Pipe.rwReadiness() & mask
+ return rw.Pipe.ReadFrom(ctx, r, count)
}
// Ioctl implements fs.FileOperations.Ioctl.
func (rw *ReaderWriter) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
- // Switch on ioctl request.
- switch int(args[1].Int()) {
- case linux.FIONREAD:
- v := rw.queued()
- if v > math.MaxInt32 {
- v = math.MaxInt32 // Silently truncate.
- }
- // Copy result to user-space.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
- return 0, err
- default:
- return 0, syscall.ENOTTY
- }
+ return rw.Pipe.Ioctl(ctx, io, args)
}
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
new file mode 100644
index 000000000..6416e0dd8
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -0,0 +1,220 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// This file contains types enabling the pipe package to be used with the vfs
+// package.
+
+// VFSPipe represents the actual pipe, analogous to an inode. VFSPipes should
+// not be copied.
+type VFSPipe struct {
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // pipe is the underlying pipe.
+ pipe Pipe
+
+ // Channels for synchronizing the creation of new readers and writers
+ // of this fifo. See waitFor and newHandleLocked.
+ //
+ // These are not saved/restored because all waiters are unblocked on
+ // save, and either automatically restart (via ERESTARTSYS) or return
+ // EINTR on resume. On restarts via ERESTARTSYS, the appropriate
+ // channel will be recreated.
+ rWakeup chan struct{} `state:"nosave"`
+ wWakeup chan struct{} `state:"nosave"`
+}
+
+// NewVFSPipe returns an initialized VFSPipe.
+func NewVFSPipe(sizeBytes, atomicIOBytes int64) *VFSPipe {
+ var vp VFSPipe
+ initPipe(&vp.pipe, true /* isNamed */, sizeBytes, atomicIOBytes)
+ return &vp
+}
+
+// NewVFSPipeFD opens a named pipe. Named pipes have special blocking semantics
+// during open:
+//
+// "Normally, opening the FIFO blocks until the other end is opened also. A
+// process can open a FIFO in nonblocking mode. In this case, opening for
+// read-only will succeed even if no-one has opened on the write side yet,
+// opening for write-only will fail with ENXIO (no such device or address)
+// unless the other end has already been opened. Under Linux, opening a FIFO
+// for read and write will succeed both in blocking and nonblocking mode. POSIX
+// leaves this behavior undefined. This can be used to open a FIFO for writing
+// while there are no readers available." - fifo(7)
+func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
+ vp.mu.Lock()
+ defer vp.mu.Unlock()
+
+ readable := vfs.MayReadFileWithOpenFlags(flags)
+ writable := vfs.MayWriteFileWithOpenFlags(flags)
+ if !readable && !writable {
+ return nil, syserror.EINVAL
+ }
+
+ vfd, err := vp.open(rp, vfsd, vfsfd, flags)
+ if err != nil {
+ return nil, err
+ }
+
+ switch {
+ case readable && writable:
+ // Pipes opened for read-write always succeed without blocking.
+ newHandleLocked(&vp.rWakeup)
+ newHandleLocked(&vp.wWakeup)
+
+ case readable:
+ newHandleLocked(&vp.rWakeup)
+ // If this pipe is being opened as blocking and there's no
+ // writer, we have to wait for a writer to open the other end.
+ if flags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) {
+ return nil, syserror.EINTR
+ }
+
+ case writable:
+ newHandleLocked(&vp.wWakeup)
+
+ if !vp.pipe.HasReaders() {
+ // Nonblocking, write-only opens fail with ENXIO when
+ // the read side isn't open yet.
+ if flags&linux.O_NONBLOCK != 0 {
+ return nil, syserror.ENXIO
+ }
+ // Wait for a reader to open the other end.
+ if !waitFor(&vp.mu, &vp.rWakeup, ctx) {
+ return nil, syserror.EINTR
+ }
+ }
+
+ default:
+ panic("invalid pipe flags: must be readable, writable, or both")
+ }
+
+ return vfd, nil
+}
+
+// Preconditions: vp.mu must be held.
+func (vp *VFSPipe) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
+ var fd VFSPipeFD
+ fd.flags = flags
+ fd.readable = vfs.MayReadFileWithOpenFlags(flags)
+ fd.writable = vfs.MayWriteFileWithOpenFlags(flags)
+ fd.vfsfd = vfsfd
+ fd.pipe = &vp.pipe
+ if fd.writable {
+ // The corresponding Mount.EndWrite() is in VFSPipe.Release().
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ }
+
+ switch {
+ case fd.readable && fd.writable:
+ vp.pipe.rOpen()
+ vp.pipe.wOpen()
+ case fd.readable:
+ vp.pipe.rOpen()
+ case fd.writable:
+ vp.pipe.wOpen()
+ default:
+ panic("invalid pipe flags: must be readable, writable, or both")
+ }
+
+ return &fd, nil
+}
+
+// VFSPipeFD implements a subset of vfs.FileDescriptionImpl for pipes. It is
+// expected that filesystems will use this in a struct implementing
+// vfs.FileDescriptionImpl.
+type VFSPipeFD struct {
+ pipe *Pipe
+ flags uint32
+ readable bool
+ writable bool
+ vfsfd *vfs.FileDescription
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *VFSPipeFD) Release() {
+ var event waiter.EventMask
+ if fd.readable {
+ fd.pipe.rClose()
+ event |= waiter.EventIn
+ }
+ if fd.writable {
+ fd.pipe.wClose()
+ event |= waiter.EventOut
+ }
+ if event == 0 {
+ panic("invalid pipe flags: must be readable, writable, or both")
+ }
+
+ if fd.writable {
+ fd.vfsfd.VirtualDentry().Mount().EndWrite()
+ }
+
+ fd.pipe.Notify(event)
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose.
+func (fd *VFSPipeFD) OnClose(_ context.Context) error {
+ return nil
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *VFSPipeFD) PRead(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *VFSPipeFD) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ if !fd.readable {
+ return 0, syserror.EINVAL
+ }
+
+ return fd.pipe.Read(ctx, dst)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *VFSPipeFD) PWrite(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+ if !fd.writable {
+ return 0, syserror.EINVAL
+ }
+
+ return fd.pipe.Write(ctx, src)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return fd.pipe.Ioctl(ctx, uio, args)
+}
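NewVFSPipeFD encodes the fifo(7) rules as a three-way switch on the open mode: read-write opens never block, blocking read-only opens wait for a writer, and write-only opens either wait for a reader or fail with ENXIO when nonblocking. A compact sketch of just that decision logic; openFIFO, fifoState, and wait are hypothetical stand-ins, not gVisor APIs:

package main

import (
	"errors"
	"fmt"
)

var (
	errEINTR = errors.New("EINTR")
	errENXIO = errors.New("ENXIO")
)

type fifoState struct {
	hasReaders, hasWriters bool
}

// openFIFO mirrors the blocking rules in NewVFSPipeFD. wait stands in
// for waitFor and reports whether the other side arrived before the
// open was interrupted.
func openFIFO(s fifoState, readable, writable, nonblock bool, wait func() bool) error {
	switch {
	case readable && writable:
		return nil // O_RDWR always succeeds without blocking.
	case readable:
		// Only a blocking read-only open waits for a writer; a
		// nonblocking one succeeds immediately.
		if !nonblock && !s.hasWriters && !wait() {
			return errEINTR
		}
		return nil
	case writable:
		if !s.hasReaders {
			if nonblock {
				return errENXIO // no reader yet; fail fast
			}
			if !wait() {
				return errEINTR
			}
		}
		return nil
	default:
		return errors.New("EINVAL")
	}
}

func main() {
	s := fifoState{}
	fmt.Println(openFIFO(s, false, true, true, nil))                          // ENXIO
	fmt.Println(openFIFO(s, true, true, false, nil))                          // <nil>
	fmt.Println(openFIFO(s, true, false, false, func() bool { return true })) // <nil>
}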
diff --git a/pkg/sentry/kernel/ptrace_arm64.go b/pkg/sentry/kernel/ptrace_arm64.go
index 0acdf769d..61e412911 100644
--- a/pkg/sentry/kernel/ptrace_arm64.go
+++ b/pkg/sentry/kernel/ptrace_arm64.go
@@ -17,7 +17,6 @@
package kernel
import (
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index 93fe68a3e..de9617e9d 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -302,7 +302,7 @@ func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Cred
return syserror.ERANGE
}
- // TODO(b/29354920): Clear undo entries in all processes
+ // TODO(gvisor.dev/issue/137): Clear undo entries in all processes.
sem.value = val
sem.pid = pid
s.changeTime = ktime.NowFromContext(ctx)
@@ -336,7 +336,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti
for i, val := range vals {
sem := &s.sems[i]
- // TODO(b/29354920): Clear undo entries in all processes
+ // TODO(gvisor.dev/issue/137): Clear undo entries in all processes.
sem.value = int16(val)
sem.pid = pid
sem.wakeWaiters()
@@ -481,7 +481,7 @@ func (s *Set) executeOps(ctx context.Context, ops []linux.Sembuf, pid int32) (ch
}
// All operations succeeded, apply them.
- // TODO(b/29354920): handle undo operations.
+ // TODO(gvisor.dev/issue/137): handle undo operations.
for i, v := range tmpVals {
s.sems[i].value = v
s.sems[i].wakeWaiters()
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
index 50b69d154..9f7e19b4d 100644
--- a/pkg/sentry/kernel/signalfd/BUILD
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "signalfd",
srcs = ["signalfd.go"],
diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go
index 220fa73a2..2fdee0282 100644
--- a/pkg/sentry/kernel/syscalls.go
+++ b/pkg/sentry/kernel/syscalls.go
@@ -339,6 +339,14 @@ func (s *SyscallTable) Lookup(sysno uintptr) SyscallFn {
return nil
}
+// LookupName looks up a syscall name.
+func (s *SyscallTable) LookupName(sysno uintptr) string {
+ if sc, ok := s.Table[sysno]; ok {
+ return sc.Name
+ }
+ return fmt.Sprintf("sys_%d", sysno) // Unlikely.
+}
+
// LookupEmulate looks up an emulation syscall number.
func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) {
sysno, ok := s.Emulate[addr]
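LookupName keeps syscall labels total: known numbers get their table name, unknown ones get a synthesized "sys_N" placeholder, which the task_syscall.go change later in this patch uses to name trace regions. A sketch of the same fallback on a plain map:

package main

import "fmt"

// lookupName mirrors SyscallTable.LookupName: return the table's name
// when the number is known, otherwise synthesize a stable placeholder.
func lookupName(table map[uintptr]string, sysno uintptr) string {
	if name, ok := table[sysno]; ok {
		return name
	}
	return fmt.Sprintf("sys_%d", sysno) // unlikely, but never empty
}

func main() {
	table := map[uintptr]string{0: "read", 1: "write"}
	fmt.Println(lookupName(table, 1))   // write
	fmt.Println(lookupName(table, 999)) // sys_999
}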
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index c82ef5486..ab0c6c4aa 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -15,6 +15,8 @@
package kernel
import (
+ gocontext "context"
+ "runtime/trace"
"sync"
"sync/atomic"
@@ -35,8 +37,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syncutil"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/third_party/gvsync"
)
// Task represents a thread of execution in the untrusted app. It
@@ -83,7 +85,7 @@ type Task struct {
//
// gosched is protected by goschedSeq. gosched is owned by the task
// goroutine.
- goschedSeq gvsync.SeqCount `state:"nosave"`
+ goschedSeq syncutil.SeqCount `state:"nosave"`
gosched TaskGoroutineSchedInfo
// yieldCount is the number of times the task goroutine has called
@@ -390,7 +392,14 @@ type Task struct {
// logPrefix is a string containing the task's thread ID in the root PID
// namespace, and is prepended to log messages emitted by Task.Infof etc.
- logPrefix atomic.Value `state:".(string)"`
+ logPrefix atomic.Value `state:"nosave"`
+
+ // traceContext and traceTask are both used for tracing, and are
+ // updated along with the logPrefix in updateInfoLocked.
+ //
+ // These are exclusive to the task goroutine.
+ traceContext gocontext.Context `state:"nosave"`
+ traceTask *trace.Task `state:"nosave"`
// creds is the task's credentials.
//
@@ -528,14 +537,6 @@ func (t *Task) loadPtraceTracer(tracer *Task) {
t.ptraceTracer.Store(tracer)
}
-func (t *Task) saveLogPrefix() string {
- return t.logPrefix.Load().(string)
-}
-
-func (t *Task) loadLogPrefix(prefix string) {
- t.logPrefix.Store(prefix)
-}
-
func (t *Task) saveSyscallFilters() []bpf.Program {
if f := t.syscallFilters.Load(); f != nil {
return f.([]bpf.Program)
@@ -549,6 +550,7 @@ func (t *Task) loadSyscallFilters(filters []bpf.Program) {
// afterLoad is invoked by stateify.
func (t *Task) afterLoad() {
+ t.updateInfoLocked()
t.interruptChan = make(chan struct{}, 1)
t.gosched.State = TaskGoroutineNonexistent
if t.stop != nil {
@@ -709,9 +711,9 @@ func (t *Task) FDTable() *FDTable {
return t.fdTable
}
-// GetFile is a convenience wrapper t.FDTable().GetFile.
+// GetFile is a convenience wrapper for t.FDTable().Get.
//
-// Precondition: same as FDTable.
+// Precondition: same as FDTable.Get.
func (t *Task) GetFile(fd int32) *fs.File {
f, _ := t.fdTable.Get(fd)
return f
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go
index dd69939f9..4a4a69ee2 100644
--- a/pkg/sentry/kernel/task_block.go
+++ b/pkg/sentry/kernel/task_block.go
@@ -16,6 +16,7 @@ package kernel
import (
"runtime"
+ "runtime/trace"
"time"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -133,19 +134,24 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error {
runtime.Gosched()
}
+ region := trace.StartRegion(t.traceContext, blockRegion)
select {
case <-C:
+ region.End()
t.SleepFinish(true)
+ // Woken by event.
return nil
case <-interrupt:
+ region.End()
t.SleepFinish(false)
// Return the indicated error on interrupt.
return syserror.ErrInterrupted
case <-timerChan:
- // We've timed out.
+ region.End()
t.SleepFinish(true)
+ // We've timed out.
return syserror.ETIMEDOUT
}
}
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 0916fd658..3eadfedb4 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -299,6 +299,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
// nt that it must receive before its task goroutine starts running.
tid := nt.k.tasks.Root.IDOfTask(nt)
defer nt.Start(tid)
+ t.traceCloneEvent(tid)
// "If fork/clone and execve are allowed by @prog, any child processes will
// be constrained to the same filters and system call ABI as the parent." -
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 8639d379f..bb5560acf 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -18,10 +18,8 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/loader"
"gvisor.dev/gvisor/pkg/sentry/mm"
@@ -132,30 +130,21 @@ func (t *Task) Stack() *arch.Stack {
return &arch.Stack{t.Arch(), t.MemoryManager(), usermem.Addr(t.Arch().Stack())}
}
-// LoadTaskImage loads filename into a new TaskContext.
+// LoadTaskImage loads a specified file into a new TaskContext.
//
-// It takes several arguments:
-// * mounts: MountNamespace to lookup filename in
-// * root: Root to lookup filename under
-// * wd: Working directory to lookup filename under
-// * maxTraversals: maximum number of symlinks to follow
-// * filename: path to binary to load
-// * file: an open fs.File object of the binary to load. If set,
-// file will be loaded and not filename.
-// * argv: Binary argv
-// * envv: Binary envv
-// * fs: Binary FeatureSet
-func (k *Kernel) LoadTaskImage(ctx context.Context, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, filename string, file *fs.File, argv, envv []string, fs *cpuid.FeatureSet) (*TaskContext, *syserr.Error) {
- // If File is not nil, we should load that instead of resolving filename.
- if file != nil {
- filename = file.MappedName(ctx)
+// args.MemoryManager does not need to be set by the caller.
+func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskContext, *syserr.Error) {
+ // If File is not nil, we should load that instead of resolving Filename.
+ if args.File != nil {
+ args.Filename = args.File.MappedName(ctx)
}
// Prepare a new user address space to load into.
m := mm.NewMemoryManager(k, k)
defer m.DecUsers(ctx)
+ args.MemoryManager = m
- os, ac, name, err := loader.Load(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv, envv, k.extraAuxv, k.vdso)
+ os, ac, name, err := loader.Load(ctx, args, k.extraAuxv, k.vdso)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 17a089b90..90a6190f1 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -129,6 +129,7 @@ type runSyscallAfterExecStop struct {
}
func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
+ t.traceExecEvent(r.tc)
t.tg.pidns.owner.mu.Lock()
t.tg.execing = nil
if t.killed() {
@@ -253,7 +254,7 @@ func (t *Task) promoteLocked() {
t.tg.leader = t
t.Infof("Becoming TID %d (in root PID namespace)", t.tg.pidns.owner.Root.tids[t])
- t.updateLogPrefixLocked()
+ t.updateInfoLocked()
// Reap the original leader. If it has a tracer, detach it instead of
// waiting for it to acknowledge the original leader's death.
oldLeader.exitParentNotified = true
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index 535f03e50..435761e5a 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -236,6 +236,7 @@ func (*runExit) execute(t *Task) taskRunState {
type runExitMain struct{}
func (*runExitMain) execute(t *Task) taskRunState {
+ t.traceExitEvent()
lastExiter := t.exitThreadGroup()
// If the task has a cleartid, and the thread group wasn't killed by a
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
index a29e9b9eb..0fb3661de 100644
--- a/pkg/sentry/kernel/task_log.go
+++ b/pkg/sentry/kernel/task_log.go
@@ -16,6 +16,7 @@ package kernel
import (
"fmt"
+ "runtime/trace"
"sort"
"gvisor.dev/gvisor/pkg/log"
@@ -127,11 +128,88 @@ func (t *Task) debugDumpStack() {
}
}
-// updateLogPrefix updates the task's cached log prefix to reflect its
-// current thread ID.
+// trace definitions.
+//
+// Note that all region names are prefixed by ':' in order to ensure that they
+// are lexically ordered before all system calls, which use the naked system
+// call name (e.g. "read") for maximum clarity.
+const (
+ traceCategory = "task"
+ runRegion = ":run"
+ blockRegion = ":block"
+ cpuidRegion = ":cpuid"
+ faultRegion = ":fault"
+)
+
+// updateInfoLocked updates the task's cached log prefix and tracing
+// information to reflect its current thread ID.
//
// Preconditions: The task's owning TaskSet.mu must be locked.
-func (t *Task) updateLogPrefixLocked() {
+func (t *Task) updateInfoLocked() {
// Use the task's TID in the root PID namespace for logging.
- t.logPrefix.Store(fmt.Sprintf("[% 4d] ", t.tg.pidns.owner.Root.tids[t]))
+ tid := t.tg.pidns.owner.Root.tids[t]
+ t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid))
+ t.rebuildTraceContext(tid)
+}
+
+// rebuildTraceContext rebuilds the trace context.
+//
+// Precondition: the passed tid must be the tid in the root namespace.
+func (t *Task) rebuildTraceContext(tid ThreadID) {
+ // Re-initialize the trace context.
+ if t.traceTask != nil {
+ t.traceTask.End()
+ }
+
+ // Note that we define the "task type" to be the dynamic TID. This does
+ // not align perfectly with the documentation for "tasks" in the
+ // tracing package. Tasks may be assumed to be bounded by analysis
+ // tools. However, if we just use a generic "task" type here, then the
+ // "user-defined tasks" page on the tracing dashboard becomes nearly
+ // unusable, as it loads all traces from all tasks.
+ //
+ // We can assume that the number of tasks in the system is not
+ // arbitrarily large (in general it won't be, especially for cases
+ // where we're collecting a brief profile), so using the TID is a
+ // reasonable compromise in this case.
+ t.traceContext, t.traceTask = trace.NewTask(t, fmt.Sprintf("tid:%d", tid))
+}
+
+// traceCloneEvent is called when a new task is spawned.
+//
+// ntid must be the new task's ThreadID in the root namespace.
+func (t *Task) traceCloneEvent(ntid ThreadID) {
+ if !trace.IsEnabled() {
+ return
+ }
+ trace.Logf(t.traceContext, traceCategory, "spawn: %d", ntid)
+}
+
+// traceExitEvent is called when a task exits.
+func (t *Task) traceExitEvent() {
+ if !trace.IsEnabled() {
+ return
+ }
+ trace.Logf(t.traceContext, traceCategory, "exit status: 0x%x", t.exitStatus.Status())
+}
+
+// traceExecEvent is called when a task calls exec.
+func (t *Task) traceExecEvent(tc *TaskContext) {
+ if !trace.IsEnabled() {
+ return
+ }
+ d := tc.MemoryManager.Executable()
+ if d == nil {
+ trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>")
+ return
+ }
+ defer d.DecRef()
+ root := t.fsContext.RootDirectory()
+ if root == nil {
+ trace.Logf(t.traceContext, traceCategory, "exec: << no root directory >>")
+ return
+ }
+ defer root.DecRef()
+ n, _ := d.FullName(root)
+ trace.Logf(t.traceContext, traceCategory, "exec: %s", n)
}
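rebuildTraceContext builds on the standard runtime/trace package: each sentry task gets its own trace.Task named "tid:N", and later regions and log events hang off that task's context. A self-contained sketch of the same API calls, standard library only:

package main

import (
	"context"
	"fmt"
	"os"
	"runtime/trace"
)

func main() {
	// Write a trace that can be inspected with `go tool trace trace.out`.
	f, err := os.Create("trace.out")
	if err != nil {
		fmt.Println(err)
		return
	}
	defer f.Close()
	if err := trace.Start(f); err != nil {
		fmt.Println(err)
		return
	}
	defer trace.Stop()

	// One trace task per logical thread, named by a made-up TID, in the
	// same style as rebuildTraceContext's "tid:%d" naming.
	ctx, task := trace.NewTask(context.Background(), "tid:42")
	defer task.End()

	// Regions nest within the task; Logf attaches categorized events.
	region := trace.StartRegion(ctx, ":run")
	trace.Logf(ctx, "task", "spawn: %d", 43)
	region.End()
}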
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index c92266c59..d97f8c189 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -17,6 +17,7 @@ package kernel
import (
"bytes"
"runtime"
+ "runtime/trace"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -205,9 +206,11 @@ func (*runApp) execute(t *Task) taskRunState {
t.tg.pidns.owner.mu.RUnlock()
}
+ region := trace.StartRegion(t.traceContext, runRegion)
t.accountTaskGoroutineEnter(TaskGoroutineRunningApp)
info, at, err := t.p.Switch(t.MemoryManager().AddressSpace(), t.Arch(), t.rseqCPU)
t.accountTaskGoroutineLeave(TaskGoroutineRunningApp)
+ region.End()
if clearSinglestep {
t.Arch().ClearSingleStep()
@@ -225,6 +228,7 @@ func (*runApp) execute(t *Task) taskRunState {
case platform.ErrContextSignalCPUID:
// Is this a CPUID instruction?
+ region := trace.StartRegion(t.traceContext, cpuidRegion)
expected := arch.CPUIDInstruction[:]
found := make([]byte, len(expected))
_, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
@@ -232,10 +236,12 @@ func (*runApp) execute(t *Task) taskRunState {
// Skip the cpuid instruction.
t.Arch().CPUIDEmulate(t)
t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected)))
+ region.End()
// Resume execution.
return (*runApp)(nil)
}
+ region.End() // Not an actual CPUID, but the copy-in was still required.
// The instruction at the given RIP was not a CPUID, and we
// fallthrough to the default signal deliver behavior below.
@@ -251,8 +257,10 @@ func (*runApp) execute(t *Task) taskRunState {
// an application-generated signal and we should continue execution
// normally.
if at.Any() {
+ region := trace.StartRegion(t.traceContext, faultRegion)
addr := usermem.Addr(info.Addr())
err := t.MemoryManager().HandleUserFault(t, addr, at, usermem.Addr(t.Arch().Stack()))
+ region.End()
if err == nil {
// The fault was handled appropriately.
// We can resume running the application.
@@ -260,6 +268,12 @@ func (*runApp) execute(t *Task) taskRunState {
}
// Is this a vsyscall that we need emulate?
+ //
+ // Note that we don't track vsyscalls as part of a
+ // specific trace region. This is because regions don't
+ // stack, and the actual system call will count as a
+ // region. We should be able to easily identify
+ // vsyscalls by having a <fault><syscall> pair.
if at.Execute {
if sysno, ok := t.tc.st.LookupEmulate(addr); ok {
return t.doVsyscall(addr, sysno)
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index ae6fc4025..3522a4ae5 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -154,10 +154,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
// Below this point, newTask is expected not to fail (there is no rollback
// of assignTIDsLocked or any of the following).
- // Logging on t's behalf will panic if t.logPrefix hasn't been initialized.
- // This is the earliest point at which we can do so (since t now has thread
- // IDs).
- t.updateLogPrefixLocked()
+ // Logging on t's behalf will panic if t.logPrefix hasn't been
+ // initialized. This is the earliest point at which we can do so
+ // (since t now has thread IDs).
+ t.updateInfoLocked()
if cfg.InheritParent != nil {
t.parent = cfg.InheritParent.parent
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
index b543d536a..3180f5560 100644
--- a/pkg/sentry/kernel/task_syscall.go
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -17,6 +17,7 @@ package kernel
import (
"fmt"
"os"
+ "runtime/trace"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -160,6 +161,10 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
ctrl = ctrlStopAndReinvokeSyscall
} else {
fn := s.Lookup(sysno)
+ var region *trace.Region // Only non-nil if tracing is enabled.
+ if trace.IsEnabled() {
+ region = trace.StartRegion(t.traceContext, s.LookupName(sysno))
+ }
if fn != nil {
// Call our syscall implementation.
rval, ctrl, err = fn(t, args)
@@ -167,6 +172,9 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
// Use the missing function if not found.
rval, err = t.SyscallTable().Missing(t, sysno, args)
}
+ if region != nil {
+ region.End()
+ }
}
if bits.IsOn32(fe, ExternalAfterEnable) && (s.ExternalFilterAfter == nil || s.ExternalFilterAfter(t, sysno, args)) {
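Note the guard around the region: the name lookup and region allocation happen only when trace.IsEnabled() reports that tracing is on, so untraced syscalls pay nothing. A sketch of that guard as a hypothetical helper:

package main

import (
	"context"
	"runtime/trace"
)

// timedSection wraps fn in a trace region only when tracing is active,
// mirroring the guard around the syscall dispatch above.
func timedSection(ctx context.Context, name string, fn func()) {
	var region *trace.Region // non-nil only if tracing is enabled
	if trace.IsEnabled() {
		region = trace.StartRegion(ctx, name)
	}
	fn()
	if region != nil {
		region.End()
	}
}

func main() {
	timedSection(context.Background(), "read", func() {})
}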
diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go
index 34f84487a..048de26dc 100644
--- a/pkg/sentry/kernel/tty.go
+++ b/pkg/sentry/kernel/tty.go
@@ -21,8 +21,19 @@ import "sync"
//
// +stateify savable
type TTY struct {
+ // Index is the terminal index. It is immutable.
+ Index uint32
+
mu sync.Mutex `state:"nosave"`
// tg is protected by mu.
tg *ThreadGroup
}
+
+// TTY returns the thread group's controlling terminal. If nil, there is no
+// controlling terminal.
+func (tg *ThreadGroup) TTY() *TTY {
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+ return tg.tty
+}
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index 2d9251e92..6299a3e2f 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -408,6 +408,8 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
start = vaddr
}
if vaddr < end {
+ // NOTE(b/37474556): Linux allows out-of-order
+ // segments, in violation of the spec.
ctx.Infof("PT_LOAD headers out-of-order. %#x < %#x", vaddr, end)
return loadedELF{}, syserror.ENOEXEC
}
@@ -624,15 +626,15 @@ func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, in
return loadParsedELF(ctx, m, f, info, 0)
}
-// loadELF loads f into the Task address space.
+// loadELF loads args.File into the Task address space.
//
// If loadELF returns ErrSwitchFile it should be called again with the returned
// path and argv.
//
// Preconditions:
-// * f is an ELF file
-func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, f *fs.File) (loadedELF, arch.Context, error) {
- bin, ac, err := loadInitialELF(ctx, m, fs, f)
+// * args.File is an ELF file
+func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error) {
+ bin, ac, err := loadInitialELF(ctx, args.MemoryManager, args.Features, args.File)
if err != nil {
ctx.Infof("Error loading binary: %v", err)
return loadedELF{}, nil, err
@@ -640,7 +642,14 @@ func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace
var interp loadedELF
if bin.interpreter != "" {
- d, i, err := openPath(ctx, mounts, root, wd, maxTraversals, bin.interpreter)
+ // Even if we do not allow the final link of the script to be
+ // resolved, the interpreter should still be resolved if it is
+ // a symlink.
+ args.ResolveFinal = true
+ // Refresh the traversal limit.
+ *args.RemainingTraversals = linux.MaxSymlinkTraversals
+ args.Filename = bin.interpreter
+ d, i, err := openPath(ctx, args)
if err != nil {
ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err)
return loadedELF{}, nil, err
@@ -649,7 +658,7 @@ func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace
// We don't need the Dirent.
d.DecRef()
- interp, err = loadInterpreterELF(ctx, m, i, bin)
+ interp, err = loadInterpreterELF(ctx, args.MemoryManager, i, bin)
if err != nil {
ctx.Infof("Error loading interpreter: %v", err)
return loadedELF{}, nil, err
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index 089d1635b..b03eeb005 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package loader loads a binary into a MemoryManager.
+// Package loader loads an executable file into a MemoryManager.
package loader
import (
@@ -20,6 +20,7 @@ import (
"fmt"
"io"
"path"
+ "strings"
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -35,6 +36,54 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LoadArgs holds specifications for an executable file to be loaded.
+type LoadArgs struct {
+ // MemoryManager is the memory manager to load the executable into.
+ MemoryManager *mm.MemoryManager
+
+ // Mounts is the mount namespace in which to look up Filename.
+ Mounts *fs.MountNamespace
+
+ // Root is the root directory under which to look up Filename.
+ Root *fs.Dirent
+
+ // WorkingDirectory is the working directory under which to look up
+ // Filename.
+ WorkingDirectory *fs.Dirent
+
+ // RemainingTraversals is the maximum number of symlinks to follow to
+ // resolve Filename. This counter is passed by reference to keep it
+ // updated throughout the call stack.
+ RemainingTraversals *uint
+
+ // ResolveFinal indicates whether the final link of Filename should be
+ // resolved, if it is a symlink.
+ ResolveFinal bool
+
+ // Filename is the path for the executable.
+ Filename string
+
+ // File is an open fs.File object of the executable. If File is not
+ // nil, then File will be loaded and Filename will be ignored.
+ File *fs.File
+
+ // CloseOnExec indicates that the executable (or one of its parent
+ // directories) was opened with O_CLOEXEC. If the executable is an
+ // interpreter script, this causes an ENOENT error, since the
+ // script would otherwise be inaccessible to the interpreter.
+ CloseOnExec bool
+
+ // Argv is the vector of arguments to pass to the executable.
+ Argv []string
+
+ // Envv is the vector of environment variables to pass to the
+ // executable.
+ Envv []string
+
+ // Features specifies the CPU feature set for the executable.
+ Features *cpuid.FeatureSet
+}
+
// readFull behaves like io.ReadFull for an *fs.File.
func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
var total int64
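LoadArgs replaces what had grown into a ten-argument signature with a single struct: call sites name only the fields they set, zero values cover the rest, and new fields can be added without touching every caller (CreateProcess above already builds one). A sketch of the pattern itself, with hypothetical names:

package main

import "fmt"

// loadArgs illustrates the argument-struct pattern used by LoadArgs.
type loadArgs struct {
	Filename     string
	Argv, Envv   []string
	ResolveFinal bool
}

func load(args loadArgs) error {
	if args.Filename == "" {
		return fmt.Errorf("cannot open empty name") // mirrors openPath's check
	}
	fmt.Printf("loading %s argv=%v resolveFinal=%v\n",
		args.Filename, args.Argv, args.ResolveFinal)
	return nil
}

func main() {
	_ = load(loadArgs{
		Filename:     "/bin/true",
		Argv:         []string{"true"},
		ResolveFinal: true,
	})
}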
@@ -51,80 +100,82 @@ func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset in
return total, nil
}
-// openPath opens name for loading.
+// openPath opens args.Filename and checks that it is valid for loading.
//
-// openPath returns the fs.Dirent and an *fs.File for name, which is not
+// openPath returns an *fs.Dirent and *fs.File for args.Filename, which is not
// installed in the Task FDTable. The caller takes ownership of both.
//
-// name must be a readable, executable, regular file.
-func openPath(ctx context.Context, mm *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, name string) (*fs.Dirent, *fs.File, error) {
- if name == "" {
+// args.Filename must be a readable, executable, regular file.
+func openPath(ctx context.Context, args LoadArgs) (*fs.Dirent, *fs.File, error) {
+ if args.Filename == "" {
ctx.Infof("cannot open empty name")
return nil, nil, syserror.ENOENT
}
- d, err := mm.FindInode(ctx, root, wd, name, maxTraversals)
+ var d *fs.Dirent
+ var err error
+ if args.ResolveFinal {
+ d, err = args.Mounts.FindInode(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals)
+ } else {
+ d, err = args.Mounts.FindLink(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals)
+ }
if err != nil {
return nil, nil, err
}
-
- // Open file will take a reference to Dirent, so destroy this one.
+ // Defer a DecRef for the sake of failure cases.
defer d.DecRef()
- return openFile(ctx, nil, d, name)
-}
+ if !args.ResolveFinal && fs.IsSymlink(d.Inode.StableAttr) {
+ return nil, nil, syserror.ELOOP
+ }
-// openFile performs checks on a file to be executed. If provided a *fs.File,
-// openFile takes that file's Dirent and performs checks on it. If provided a
-// *fs.Dirent and not a *fs.File, it creates a *fs.File object from the Dirent's
-// Inode and performs checks on that.
-//
-// openFile returns an *fs.File and *fs.Dirent, and the caller takes ownership
-// of both.
-//
-// "dirent" and "file" must not both be nil and point to a readable, executable, regular file.
-func openFile(ctx context.Context, file *fs.File, dirent *fs.Dirent, name string) (*fs.Dirent, *fs.File, error) {
- // file and dirent must not be nil.
- if dirent == nil && file == nil {
- ctx.Infof("dirent and file cannot both be nil.")
- return nil, nil, syserror.ENOENT
+ if err := checkPermission(ctx, d); err != nil {
+ return nil, nil, err
}
- if file != nil {
- dirent = file.Dirent
+ // If they claim it's a directory, then make sure.
+ //
+ // N.B. we reject directories below, but we must first reject
+ // non-directories passed as directories.
+ if strings.HasSuffix(args.Filename, "/") && !fs.IsDir(d.Inode.StableAttr) {
+ return nil, nil, syserror.ENOTDIR
}
- // Perform permissions checks on the file.
- if err := checkFile(ctx, dirent, name); err != nil {
+ if err := checkIsRegularFile(ctx, d, args.Filename); err != nil {
return nil, nil, err
}
- if file == nil {
- var ferr error
- if file, ferr = dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true}); ferr != nil {
- return nil, nil, ferr
- }
- } else {
- // GetFile takes a reference to the created file, so make one in the case
- // that the file reference already existed.
- file.IncRef()
+ f, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
+ if err != nil {
+ return nil, nil, err
+ }
+ // Defer a DecRef for the sake of failure cases.
+ defer f.DecRef()
+
+ if err := checkPread(ctx, f, args.Filename); err != nil {
+ return nil, nil, err
+ }
+
+ d.IncRef()
+ f.IncRef()
+ return d, f, err
+}
+
+// checkFile performs checks on a file to be executed.
+func checkFile(ctx context.Context, f *fs.File, filename string) error {
+ if err := checkPermission(ctx, f.Dirent); err != nil {
+ return err
}
- // We must be able to read at arbitrary offsets.
- if !file.Flags().Pread {
- file.DecRef()
- ctx.Infof("%s cannot be read at an offset: %+v", file.MappedName(ctx), file.Flags())
- return nil, nil, syserror.EACCES
+ if err := checkIsRegularFile(ctx, f.Dirent, filename); err != nil {
+ return err
}
- // Grab reference for caller.
- dirent.IncRef()
- return dirent, file, nil
+ return checkPread(ctx, f, filename)
}
-// checkFile performs file permissions checks for binaries called in openPath
-// and openFile
-func checkFile(ctx context.Context, d *fs.Dirent, name string) error {
+// checkPermission checks whether the file is readable and executable.
+func checkPermission(ctx context.Context, d *fs.Dirent) error {
perms := fs.PermMask{
// TODO(gvisor.dev/issue/160): Linux requires only execute
// permission, not read. However, our backing filesystems may
@@ -135,26 +186,26 @@ func checkFile(ctx context.Context, d *fs.Dirent, name string) error {
Read: true,
Execute: true,
}
- if err := d.Inode.CheckPermission(ctx, perms); err != nil {
- return err
- }
+ return d.Inode.CheckPermission(ctx, perms)
+}
- // If they claim it's a directory, then make sure.
- //
- // N.B. we reject directories below, but we must first reject
- // non-directories passed as directories.
- if len(name) > 0 && name[len(name)-1] == '/' && !fs.IsDir(d.Inode.StableAttr) {
- return syserror.ENOTDIR
+// checkIsRegularFile prevents us from trying to execute a directory, pipe, etc.
+func checkIsRegularFile(ctx context.Context, d *fs.Dirent, filename string) error {
+ attr := d.Inode.StableAttr
+ if !fs.IsRegular(attr) {
+ ctx.Infof("%s is not regular: %v", filename, attr)
+ return syserror.EACCES
}
+ return nil
+}
- // No exec-ing directories, pipes, etc!
- if !fs.IsRegular(d.Inode.StableAttr) {
- ctx.Infof("%s is not regular: %v", name, d.Inode.StableAttr)
+// checkPread checks whether we can read the file at arbitrary offsets.
+func checkPread(ctx context.Context, f *fs.File, filename string) error {
+ if !f.Flags().Pread {
+ ctx.Infof("%s cannot be read at an offset: %+v", filename, f.Flags())
return syserror.EACCES
}
-
return nil
-
}
// allocStack allocates and maps a stack in to any available part of the address space.
@@ -173,45 +224,49 @@ const (
maxLoaderAttempts = 6
)
-// loadBinary loads a binary that is pointed to by "file". If nil, the path
-// "filename" is resolved and loaded.
+// loadExecutable loads an executable that is pointed to by args.File. If nil,
+// the path args.Filename is resolved and loaded. If the executable is an
+// interpreter script rather than an ELF, the binary of the corresponding
+// interpreter will be loaded.
//
// It returns:
// * loadedELF, description of the loaded binary
// * arch.Context matching the binary arch
// * fs.Dirent of the binary file
-// * Possibly updated argv
-func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, remainingTraversals *uint, features *cpuid.FeatureSet, filename string, passedFile *fs.File, argv []string) (loadedELF, arch.Context, *fs.Dirent, []string, error) {
+// * Possibly updated args.Argv
+func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, *fs.Dirent, []string, error) {
for i := 0; i < maxLoaderAttempts; i++ {
var (
d *fs.Dirent
- f *fs.File
err error
)
- if passedFile == nil {
- d, f, err = openPath(ctx, mounts, root, wd, remainingTraversals, filename)
-
+ if args.File == nil {
+ d, args.File, err = openPath(ctx, args)
+ // We will return d in the successful case, but defer a DecRef for the
+ // sake of intermediate loops and failure cases.
+ if d != nil {
+ defer d.DecRef()
+ }
+ if args.File != nil {
+ defer args.File.DecRef()
+ }
} else {
- d, f, err = openFile(ctx, passedFile, nil, "")
- // Set to nil in case we loop on a Interpreter Script.
- passedFile = nil
+ d = args.File.Dirent
+ d.IncRef()
+ defer d.DecRef()
+ err = checkFile(ctx, args.File, args.Filename)
}
-
if err != nil {
- ctx.Infof("Error opening %s: %v", filename, err)
+ ctx.Infof("Error opening %s: %v", args.Filename, err)
return loadedELF{}, nil, nil, nil, err
}
- defer f.DecRef()
- // We will return d in the successful case, but defer a DecRef
- // for intermediate loops and failure cases.
- defer d.DecRef()
// Check the header. Is this an ELF or interpreter script?
var hdr [4]uint8
// N.B. We assume that reading from a regular file cannot block.
- _, err = readFull(ctx, f, usermem.BytesIOSequence(hdr[:]), 0)
- // Allow unexpected EOF, as a valid executable could be only three
- // bytes (e.g., #!a).
+ _, err = readFull(ctx, args.File, usermem.BytesIOSequence(hdr[:]), 0)
+ // Allow unexpected EOF, as a valid executable could be only three bytes
+ // (e.g., #!a).
if err != nil && err != io.ErrUnexpectedEOF {
if err == io.EOF {
err = syserror.ENOEXEC
@@ -221,33 +276,38 @@ func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamesp
switch {
case bytes.Equal(hdr[:], []byte(elfMagic)):
- loaded, ac, err := loadELF(ctx, m, mounts, root, wd, remainingTraversals, features, f)
+ loaded, ac, err := loadELF(ctx, args)
if err != nil {
ctx.Infof("Error loading ELF: %v", err)
return loadedELF{}, nil, nil, nil, err
}
// An ELF is always terminal. Hold on to d.
d.IncRef()
- return loaded, ac, d, argv, err
+ return loaded, ac, d, args.Argv, err
case bytes.Equal(hdr[:2], []byte(interpreterScriptMagic)):
- newpath, newargv, err := parseInterpreterScript(ctx, filename, f, argv)
+ if args.CloseOnExec {
+ return loadedELF{}, nil, nil, nil, syserror.ENOENT
+ }
+ args.Filename, args.Argv, err = parseInterpreterScript(ctx, args.Filename, args.File, args.Argv)
if err != nil {
ctx.Infof("Error loading interpreter script: %v", err)
return loadedELF{}, nil, nil, nil, err
}
- filename = newpath
- argv = newargv
+ // Refresh the traversal limit for the interpreter.
+ *args.RemainingTraversals = linux.MaxSymlinkTraversals
default:
ctx.Infof("Unknown magic: %v", hdr)
return loadedELF{}, nil, nil, nil, syserror.ENOEXEC
}
+ // Set to nil in case we loop on an interpreter script.
+ args.File = nil
}
return loadedELF{}, nil, nil, nil, syserror.ELOOP
}
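
To make the interpreter loop concrete, a hedged illustration of one iteration (the values are illustrative only; parseInterpreterScript is not shown in this diff):

    // Illustrative only: executing "./run.sh" whose first line is "#!/bin/sh -e".
    // After the interpreterScriptMagic case, roughly:
    //   args.Filename             == "/bin/sh"
    //   args.Argv                 == []string{"/bin/sh", "-e", "./run.sh"}
    //   *args.RemainingTraversals == linux.MaxSymlinkTraversals // budget refreshed
    //   args.File                 == nil // forces openPath on the next iteration
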
-// Load loads "file" into a MemoryManager. If file is nil, the path "filename"
-// is resolved and loaded instead.
+// Load loads args.File into a MemoryManager. If args.File is nil, the path
+// args.Filename is resolved and loaded instead.
//
// If Load returns ErrSwitchFile it should be called again with the returned
// path and argv.
@@ -255,37 +315,37 @@ func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamesp
// Preconditions:
// * The Task MemoryManager is empty.
// * Load is called on the Task goroutine.
-func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, filename string, file *fs.File, argv, envv []string, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) {
- // Load the binary itself.
- loaded, ac, d, argv, err := loadBinary(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv)
+func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) {
+ // Load the executable itself.
+ loaded, ac, d, newArgv, err := loadExecutable(ctx, args)
if err != nil {
- return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", filename, err), syserr.FromError(err).ToLinux())
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux())
}
defer d.DecRef()
// Load the VDSO.
- vdsoAddr, err := loadVDSO(ctx, m, vdso, loaded)
+ vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded)
if err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux())
}
// Setup the heap. brk starts at the next page after the end of the
- // binary. Userspace can assume that the remainer of the page after
+ // executable. Userspace can assume that the remainder of the page after
// loaded.end is available for its use.
e, ok := loaded.end.RoundUp()
if !ok {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), linux.ENOEXEC)
}
- m.BrkSetup(ctx, e)
+ args.MemoryManager.BrkSetup(ctx, e)
// Allocate our stack.
- stack, err := allocStack(ctx, m, ac)
+ stack, err := allocStack(ctx, args.MemoryManager, ac)
if err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to allocate stack: %v", err), syserr.FromError(err).ToLinux())
}
// Push the original filename to the stack, for AT_EXECFN.
- execfn, err := stack.Push(filename)
+ execfn, err := stack.Push(args.Filename)
if err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push exec filename: %v", err), syserr.FromError(err).ToLinux())
}
@@ -319,11 +379,12 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r
}...)
auxv = append(auxv, extraAuxv...)
- sl, err := stack.Load(argv, envv, auxv)
+ sl, err := stack.Load(newArgv, args.Envv, auxv)
if err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load stack: %v", err), syserr.FromError(err).ToLinux())
}
+ m := args.MemoryManager
m.SetArgvStart(sl.ArgvStart)
m.SetArgvEnd(sl.ArgvEnd)
m.SetEnvvStart(sl.EnvvStart)
@@ -334,7 +395,7 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r
ac.SetIP(uintptr(loaded.entry))
ac.SetStack(uintptr(stack.Bottom))
- name := path.Base(filename)
+ name := path.Base(args.Filename)
if len(name) > linux.TASK_COMM_LEN-1 {
name = name[:linux.TASK_COMM_LEN-1]
}
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index ada28aea3..df8a81907 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -268,6 +268,8 @@ func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, er
// some applications may not be able to handle multiple [vdso]
// hints.
vdso: mm.NewSpecialMappable("", mfp, vdso),
+ os: info.os,
+ arch: info.arch,
phdrs: info.phdrs,
}, nil
}
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
index 3ef84245b..112794e9c 100644
--- a/pkg/sentry/memmap/BUILD
+++ b/pkg/sentry/memmap/BUILD
@@ -41,7 +41,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/log",
- "//pkg/refs",
"//pkg/sentry/context",
"//pkg/sentry/platform",
"//pkg/sentry/usermem",
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index 03b99aaea..16a722a13 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -18,7 +18,6 @@ package memmap
import (
"fmt"
- "gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usermem"
@@ -235,8 +234,11 @@ type InvalidateOpts struct {
// coincidental; fs.File implements MappingIdentity, and some
// fs.InodeOperations implement Mappable.)
type MappingIdentity interface {
- // MappingIdentity is reference-counted.
- refs.RefCounter
+ // IncRef increments the MappingIdentity's reference count.
+ IncRef()
+
+ // DecRef decrements the MappingIdentity's reference count.
+ DecRef()
// MappedName returns the application-visible name shown in
// /proc/[pid]/maps.
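
With refs.RefCounter gone from the interface, any type carrying matching IncRef/DecRef methods satisfies MappingIdentity. A minimal sketch with a hypothetical type (the remaining interface methods are omitted):

    package example

    import "sync/atomic"

    // countedIdentity is a toy refcount for a MappingIdentity implementation.
    type countedIdentity struct {
        refs int64
    }

    // IncRef increments the reference count.
    func (c *countedIdentity) IncRef() { atomic.AddInt64(&c.refs, 1) }

    // DecRef decrements the reference count, releasing resources at zero.
    func (c *countedIdentity) DecRef() {
        if atomic.AddInt64(&c.refs, -1) == 0 {
            // Release backing resources here.
        }
    }
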
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index a804b8b5c..839931f67 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -118,9 +118,9 @@ go_library(
"//pkg/sentry/safemem",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
+ "//pkg/syncutil",
"//pkg/syserror",
"//pkg/tcpip/buffer",
- "//third_party/gvsync",
],
)
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index f350e0109..58a5c186d 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -44,7 +44,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/third_party/gvsync"
+ "gvisor.dev/gvisor/pkg/syncutil"
)
// MemoryManager implements a virtual address space.
@@ -82,7 +82,7 @@ type MemoryManager struct {
users int32
// mappingMu is analogous to Linux's struct mm_struct::mmap_sem.
- mappingMu gvsync.DowngradableRWMutex `state:"nosave"`
+ mappingMu syncutil.DowngradableRWMutex `state:"nosave"`
// vmas stores virtual memory areas. Since vmas are stored by value,
// clients should usually use vmaIterator.ValuePtr() instead of
@@ -125,7 +125,7 @@ type MemoryManager struct {
// activeMu is loosely analogous to Linux's struct
// mm_struct::page_table_lock.
- activeMu gvsync.DowngradableRWMutex `state:"nosave"`
+ activeMu syncutil.DowngradableRWMutex `state:"nosave"`
// pmas stores platform mapping areas used to implement vmas. Since pmas
// are stored by value, clients should usually use pmaIterator.ValuePtr()
diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go
index 1effc7735..aafce1d00 100644
--- a/pkg/sentry/pgalloc/save_restore.go
+++ b/pkg/sentry/pgalloc/save_restore.go
@@ -16,6 +16,7 @@ package pgalloc
import (
"bytes"
+ "context"
"fmt"
"io"
"runtime"
@@ -29,7 +30,7 @@ import (
)
// SaveTo writes f's state to the given stream.
-func (f *MemoryFile) SaveTo(w io.Writer) error {
+func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error {
// Wait for reclaim.
f.mu.Lock()
defer f.mu.Unlock()
@@ -78,10 +79,10 @@ func (f *MemoryFile) SaveTo(w io.Writer) error {
}
// Save metadata.
- if err := state.Save(w, &f.fileSize, nil); err != nil {
+ if err := state.Save(ctx, w, &f.fileSize, nil); err != nil {
return err
}
- if err := state.Save(w, &f.usage, nil); err != nil {
+ if err := state.Save(ctx, w, &f.usage, nil); err != nil {
return err
}
@@ -114,9 +115,9 @@ func (f *MemoryFile) SaveTo(w io.Writer) error {
}
// LoadFrom loads MemoryFile state from the given stream.
-func (f *MemoryFile) LoadFrom(r io.Reader) error {
+func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error {
// Load metadata.
- if err := state.Load(r, &f.fileSize, nil); err != nil {
+ if err := state.Load(ctx, r, &f.fileSize, nil); err != nil {
return err
}
if err := f.file.Truncate(f.fileSize); err != nil {
@@ -124,7 +125,7 @@ func (f *MemoryFile) LoadFrom(r io.Reader) error {
}
newMappings := make([]uintptr, f.fileSize>>chunkShift)
f.mappings.Store(newMappings)
- if err := state.Load(r, &f.usage, nil); err != nil {
+ if err := state.Load(ctx, r, &f.usage, nil); err != nil {
return err
}
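
A hedged sketch of the caller-side adaptation to the new signatures; the names saveMemory/loadMemory are hypothetical, and how the kernel obtains ctx and the streams is assumed:

    import (
        "context"
        "io"

        "gvisor.dev/gvisor/pkg/sentry/pgalloc"
    )

    // saveMemory shows the extra ctx argument; previously f.SaveTo(w).
    func saveMemory(ctx context.Context, f *pgalloc.MemoryFile, w io.Writer) error {
        return f.SaveTo(ctx, w)
    }

    // loadMemory shows the extra ctx argument; previously f.LoadFrom(r).
    func loadMemory(ctx context.Context, f *pgalloc.MemoryFile, r io.Reader) error {
        return f.LoadFrom(ctx, r)
    }
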
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 31fa48ec5..f3afd98da 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -12,19 +12,30 @@ go_library(
"bluepill_amd64.go",
"bluepill_amd64.s",
"bluepill_amd64_unsafe.go",
+ "bluepill_arm64.go",
+ "bluepill_arm64.s",
+ "bluepill_arm64_unsafe.go",
"bluepill_fault.go",
"bluepill_unsafe.go",
"context.go",
- "filters.go",
+ "filters_amd64.go",
+ "filters_arm64.go",
"kvm.go",
"kvm_amd64.go",
"kvm_amd64_unsafe.go",
+ "kvm_arm64.go",
+ "kvm_arm64_unsafe.go",
"kvm_const.go",
+ "kvm_const_arm64.go",
"machine.go",
"machine_amd64.go",
"machine_amd64_unsafe.go",
+ "machine_arm64.go",
+ "machine_arm64_unsafe.go",
"machine_unsafe.go",
"physical_map.go",
+ "physical_map_amd64.go",
+ "physical_map_arm64.go",
"virtual_map.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm",
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index acd41f73d..ea8b9632e 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -127,7 +127,7 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac
// not have physical mappings, the KVM module may inject
// spurious exceptions when emulation fails (i.e. it tries to
// emulate because the RIP is pointed at those pages).
- as.machine.mapPhysical(physical, length)
+ as.machine.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE)
// Install the page table mappings. Note that the ordering is
// important; if the pagetable mappings were installed before
diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/allocator.go
index 80942e9c9..3f35414bb 100644
--- a/pkg/sentry/platform/kvm/allocator.go
+++ b/pkg/sentry/platform/kvm/allocator.go
@@ -54,7 +54,7 @@ func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr {
//
//go:nosplit
func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs {
- virtualStart, physicalStart, _, ok := calculateBluepillFault(physical)
+ virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions)
if !ok {
panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical))
}
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
index 043de51b3..30dbb74d6 100644
--- a/pkg/sentry/platform/kvm/bluepill.go
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -20,6 +20,7 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/safecopy"
)
@@ -36,6 +37,18 @@ func sighandler()
func dieTrampoline()
var (
+ // bounceSignal is the signal used for bouncing KVM.
+ //
+ // We use SIGCHLD because it is not masked by the runtime, and
+ // it will be ignored properly by other parts of the kernel.
+ bounceSignal = syscall.SIGCHLD
+
+ // bounceSignalMask has only bounceSignal set.
+ bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1))
+
+ // bounce is the interrupt vector used to return to the kernel.
+ bounce = uint32(ring0.VirtualizationException)
+
// savedHandler is a pointer to the previous handler.
//
// This is called by bluepillHandler.
@@ -45,6 +58,13 @@ var (
dieTrampolineAddr uintptr
)
+// redpill invokes a syscall with -1.
+//
+//go:nosplit
+func redpill() {
+ syscall.RawSyscall(^uintptr(0), 0, 0, 0)
+}
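
redpill pairs with the arm64 KernelSyscall handler added later in this diff: the guest issues syscall number -1, which the kernel side recognizes as the transition marker and therefore does not rewind. A sketch of the invariant, assuming the register layout used below:

    // Guest side (redpill): svc with the syscall number set to -1.
    //   syscall.RawSyscall(^uintptr(0), 0, 0, 0)
    // Host side (KernelSyscall, bluepill_arm64.go): rewind only real syscalls
    // so they re-execute on the host after the switch.
    //   if regs.Regs[8] != ^uint64(0) {
    //       regs.Pc -= 4 // Rewind.
    //   }
    //   ring0.Halt()
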
+
// dieHandler is called by dieTrampoline.
//
//go:nosplit
@@ -73,8 +93,8 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) {
func init() {
// Install the handler.
- if err := safecopy.ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil {
- panic(fmt.Sprintf("Unable to set handler for signal %d: %v", syscall.SIGSEGV, err))
+ if err := safecopy.ReplaceSignalHandler(bluepillSignal, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil {
+ panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
}
// Extract the address for the trampoline.
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go
index 421c88220..133c2203d 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.go
@@ -24,26 +24,10 @@ import (
)
var (
- // bounceSignal is the signal used for bouncing KVM.
- //
- // We use SIGCHLD because it is not masked by the runtime, and
- // it will be ignored properly by other parts of the kernel.
- bounceSignal = syscall.SIGCHLD
-
- // bounceSignalMask has only bounceSignal set.
- bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1))
-
- // bounce is the interrupt vector used to return to the kernel.
- bounce = uint32(ring0.VirtualizationException)
+ // The action for bluepillSignal is changed by sigaction().
+ bluepillSignal = syscall.SIGSEGV
)
-// redpill on amd64 invokes a syscall with -1.
-//
-//go:nosplit
-func redpill() {
- syscall.RawSyscall(^uintptr(0), 0, 0, 0)
-}
-
// bluepillArchEnter is called during bluepillEnter.
//
//go:nosplit
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index 9d8af143e..a63a6a071 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -23,13 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
)
-// bluepillArchContext returns the arch-specific context.
-//
-//go:nosplit
-func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 {
- return &((*arch.UContext64)(context).MContext)
-}
-
// dieArchSetup initializes the state for dieTrampoline.
//
// The amd64 dieTrampoline requires the vCPU to be set in BX, and the last RIP
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
new file mode 100644
index 000000000..552341721
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -0,0 +1,79 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+)
+
+var (
+ // The action for bluepillSignal is changed by sigaction().
+ bluepillSignal = syscall.SIGILL
+)
+
+// bluepillArchEnter is called during bluepillEnter.
+//
+//go:nosplit
+func bluepillArchEnter(context *arch.SignalContext64) (c *vCPU) {
+ c = vCPUPtr(uintptr(context.Regs[8]))
+ regs := c.CPU.Registers()
+ regs.Regs = context.Regs
+ regs.Sp = context.Sp
+ regs.Pc = context.Pc
+ regs.Pstate = context.Pstate
+ regs.Pstate &^= uint64(ring0.KernelFlagsClear)
+ regs.Pstate |= ring0.KernelFlagsSet
+ return
+}
+
+// bluepillArchExit is called when leaving guest mode, to copy the vCPU
+// registers back into the signal context.
+//
+//go:nosplit
+func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
+ regs := c.CPU.Registers()
+ context.Regs = regs.Regs
+ context.Sp = regs.Sp
+ context.Pc = regs.Pc
+ context.Pstate = regs.Pstate
+ context.Pstate &^= uint64(ring0.UserFlagsClear)
+ context.Pstate |= ring0.UserFlagsSet
+}
+
+// KernelSyscall handles kernel syscalls.
+//
+//go:nosplit
+func (c *vCPU) KernelSyscall() {
+ regs := c.Registers()
+ if regs.Regs[8] != ^uint64(0) {
+ regs.Pc -= 4 // Rewind.
+ }
+ ring0.Halt()
+}
+
+// KernelException handles kernel exceptions.
+//
+//go:nosplit
+func (c *vCPU) KernelException(vector ring0.Vector) {
+ regs := c.Registers()
+ if vector == ring0.Vector(bounce) {
+ regs.Pc = 0
+ }
+ ring0.Halt()
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s
new file mode 100644
index 000000000..c61700892
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.s
@@ -0,0 +1,87 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// VCPU_CPU is the location of the CPU in the vCPU struct.
+//
+// This is guaranteed to be zero.
+#define VCPU_CPU 0x0
+
+// CPU_SELF is the self reference in ring0's percpu.
+//
+// This is guaranteed to be zero.
+#define CPU_SELF 0x0
+
+// Context offsets.
+//
+// Only limited use of the context is done in the assembly stub below, most is
+// done in the Go handlers.
+#define SIGINFO_SIGNO 0x0
+#define CONTEXT_PC 0x1B8
+#define CONTEXT_R0 0xB8
+
+// See bluepill.go.
+TEXT ·bluepill(SB),NOSPLIT,$0
+begin:
+ MOVD vcpu+0(FP), R8
+ MOVD $VCPU_CPU(R8), R9
+ ORR $0xffff000000000000, R9, R9
+ // Trigger sigill.
+ // In ring0.Start(), the value of R8 will be stored into tpidr_el1.
+ // Once the context has been loaded into the vCPU successfully,
+ // we check whether the values of R9 and R10 are the same.
+ WORD $0xd538d08a // MRS TPIDR_EL1, R10
+check_vcpu:
+ CMP R10, R9
+ BEQ right_vCPU
+wrong_vcpu:
+ CALL ·redpill(SB)
+ B begin
+right_vCPU:
+ RET
+
+// sighandler: see bluepill.go for documentation.
+//
+// The arguments are the following:
+//
+// R0 - The signal number.
+// R1 - Pointer to siginfo_t structure.
+// R2 - Pointer to ucontext structure.
+//
+TEXT ·sighandler(SB),NOSPLIT,$0
+ // si_signo should be sigill.
+ MOVD SIGINFO_SIGNO(R1), R7
+ CMPW $4, R7
+ BNE fallback
+
+ MOVD CONTEXT_PC(R2), R7
+ CMPW $0, R7
+ BEQ fallback
+
+ MOVD R2, 8(RSP)
+ BL ·bluepillHandler(SB) // Call the handler.
+
+ RET
+
+fallback:
+ // Jump to the previous signal handler.
+ MOVD ·savedHandler(SB), R7
+ B (R7)
+
+// dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation.
+TEXT ·dieTrampoline(SB),NOSPLIT,$0
+ // TODO(gvisor.dev/issue/1249): dieTrampoline support for Arm64.
+ MOVD R9, 8(RSP)
+ BL ·dieHandler(SB)
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
new file mode 100644
index 000000000..e5fac0d6a
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+//go:nosplit
+func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
+ // TODO(gvisor.dev/issue/1249): dieTrampoline support for Arm64.
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go
index b97476053..f6459cda9 100644
--- a/pkg/sentry/platform/kvm/bluepill_fault.go
+++ b/pkg/sentry/platform/kvm/bluepill_fault.go
@@ -46,9 +46,9 @@ func yield() {
// calculateBluepillFault calculates the fault address range.
//
//go:nosplit
-func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, length uintptr, ok bool) {
+func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virtualStart, physicalStart, length uintptr, ok bool) {
alignedPhysical := physical &^ uintptr(usermem.PageSize-1)
- for _, pr := range physicalRegions {
+ for _, pr := range phyRegions {
end := pr.physical + pr.length
if physical < pr.physical || physical >= end {
continue
@@ -77,12 +77,12 @@ func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, leng
// The corresponding virtual address is returned. This may throw on error.
//
//go:nosplit
-func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) {
+func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegion, flags uint32) (uintptr, bool) {
// Paging fault: we need to map the underlying physical pages for this
// fault. This all has to be done in this function because we're in a
// signal handler context. (We can't call any functions that might
// split the stack.)
- virtualStart, physicalStart, length, ok := calculateBluepillFault(physical)
+ virtualStart, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions)
if !ok {
return 0, false
}
@@ -96,7 +96,7 @@ func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) {
yield() // Race with another call.
slot = atomic.SwapUint32(&m.nextSlot, ^uint32(0))
}
- errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart)
+ errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart, flags)
if errno == 0 {
// Successfully added region; we can increment nextSlot and
// allow another set to proceed here.
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 7e8e9f42a..9add7c944 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.12
-// +build !go1.14
+// +build !go1.15
// Check go:linkname function signatures when updating Go version.
@@ -23,6 +23,8 @@ import (
"sync/atomic"
"syscall"
"unsafe"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
//go:linkname throw runtime.throw
@@ -49,6 +51,13 @@ func uintptrValue(addr *byte) uintptr {
return (uintptr)(unsafe.Pointer(addr))
}
+// bluepillArchContext returns the UContext64.
+//
+//go:nosplit
+func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 {
+ return &((*arch.UContext64)(context).MContext)
+}
+
// bluepillHandler is called from the signal stub.
//
// The world may be stopped while this is executing, and it executes on the
@@ -80,13 +89,17 @@ func bluepillHandler(context unsafe.Pointer) {
// interrupted KVM. Since we're in a signal handler
// currently, all signals are masked and the signal
// must have been delivered directly to this thread.
+ timeout := syscall.Timespec{}
sig, _, errno := syscall.RawSyscall6(
syscall.SYS_RT_SIGTIMEDWAIT,
uintptr(unsafe.Pointer(&bounceSignalMask)),
- 0, // siginfo.
- 0, // timeout.
- 8, // sigset size.
+ 0, // siginfo.
+ uintptr(unsafe.Pointer(&timeout)), // timeout.
+ 8, // sigset size.
0, 0)
+ if errno == syscall.EAGAIN {
+ continue
+ }
if errno != 0 {
throw("error waiting for pending signal")
}
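
The zero timespec turns rt_sigtimedwait into a non-blocking poll, so EAGAIN here simply means no signal is pending yet. A standalone sketch of the same pattern, assuming a 64-bit signal mask:

    // Minimal sketch of non-blocking signal polling via rt_sigtimedwait.
    mask := uint64(1 << (uint64(syscall.SIGCHLD) - 1))
    timeout := syscall.Timespec{} // Zero timeout => poll, never block.
    sig, _, errno := syscall.RawSyscall6(
        syscall.SYS_RT_SIGTIMEDWAIT,
        uintptr(unsafe.Pointer(&mask)),
        0,                                 // siginfo: not needed.
        uintptr(unsafe.Pointer(&timeout)), // timeout.
        8,                                 // sigset size in bytes.
        0, 0)
    if errno == syscall.EAGAIN {
        // Nothing pending; retry later.
    }
    _ = sig
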
@@ -162,7 +175,7 @@ func bluepillHandler(context unsafe.Pointer) {
// For MMIO, the physical address is the first data item.
physical := uintptr(c.runData.data[0])
- virtual, ok := handleBluepillFault(c.machine, physical)
+ virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE)
if !ok {
c.die(bluepillArchContext(context), "invalid physical address")
return
diff --git a/pkg/sentry/platform/kvm/filters.go b/pkg/sentry/platform/kvm/filters_amd64.go
index 7d949f1dd..7d949f1dd 100644
--- a/pkg/sentry/platform/kvm/filters.go
+++ b/pkg/sentry/platform/kvm/filters_amd64.go
diff --git a/pkg/sentry/platform/kvm/filters_arm64.go b/pkg/sentry/platform/kvm/filters_arm64.go
new file mode 100644
index 000000000..9245d07c2
--- /dev/null
+++ b/pkg/sentry/platform/kvm/filters_arm64.go
@@ -0,0 +1,32 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// SyscallFilters returns syscalls made exclusively by the KVM platform.
+func (*KVM) SyscallFilters() seccomp.SyscallRules {
+ return seccomp.SyscallRules{
+ syscall.SYS_IOCTL: {},
+ syscall.SYS_MMAP: {},
+ syscall.SYS_RT_SIGSUSPEND: {},
+ syscall.SYS_RT_SIGTIMEDWAIT: {},
+ 0xffffffffffffffff: {}, // KVM uses syscall -1 to transition to host.
+ }
+}
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index ee4cd2f4d..f2c2c059e 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -21,7 +21,6 @@ import (
"sync"
"syscall"
- "gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
@@ -56,9 +55,7 @@ func New(deviceFile *os.File) (*KVM, error) {
// Ensure global initialization is done.
globalOnce.Do(func() {
- physicalInit()
- globalErr = updateSystemValues(int(fd))
- ring0.Init(cpuid.HostFeatureSet())
+ globalErr = updateGlobalOnce(int(fd))
})
if globalErr != nil {
return nil, globalErr
diff --git a/pkg/sentry/platform/kvm/kvm_amd64.go b/pkg/sentry/platform/kvm/kvm_amd64.go
index 5d8ef4761..c5a6f9c7d 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64.go
@@ -17,6 +17,7 @@
package kvm
import (
+ "gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
)
@@ -211,3 +212,11 @@ type cpuidEntries struct {
_ uint32
entries [_KVM_NR_CPUID_ENTRIES]cpuidEntry
}
+
+// updateGlobalOnce does global initialization. It has to be called only once.
+func updateGlobalOnce(fd int) error {
+ physicalInit()
+ err := updateSystemValues(int(fd))
+ ring0.Init(cpuid.HostFeatureSet())
+ return err
+}
diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go
new file mode 100644
index 000000000..2319c86d3
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_arm64.go
@@ -0,0 +1,83 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "syscall"
+)
+
+// userMemoryRegion is a region of physical memory.
+//
+// This mirrors kvm_memory_region.
+type userMemoryRegion struct {
+ slot uint32
+ flags uint32
+ guestPhysAddr uint64
+ memorySize uint64
+ userspaceAddr uint64
+}
+
+type kvmOneReg struct {
+ id uint64
+ addr uint64
+}
+
+const KVM_NR_SPSR = 5
+
+type userFpsimdState struct {
+ vregs [64]uint64
+ fpsr uint32
+ fpcr uint32
+ reserved [2]uint32
+}
+
+type userRegs struct {
+ Regs syscall.PtraceRegs
+ sp_el1 uint64
+ elr_el1 uint64
+ spsr [KVM_NR_SPSR]uint64
+ fpRegs userFpsimdState
+}
+
+// runData is the run structure. This may be mapped for synchronous register
+// access (although that doesn't appear to be supported by my kernel at least).
+//
+// This mirrors kvm_run.
+type runData struct {
+ requestInterruptWindow uint8
+ _ [7]uint8
+
+ exitReason uint32
+ readyForInterruptInjection uint8
+ ifFlag uint8
+ _ [2]uint8
+
+ cr8 uint64
+ apicBase uint64
+
+ // This is the union data for exits. Interpretation depends entirely on
+ // the exitReason above (see vCPU code for more information).
+ data [32]uint64
+}
+
+// updateGlobalOnce does global initialization. It has to be called only once.
+func updateGlobalOnce(fd int) error {
+ physicalInit()
+ err := updateSystemValues(int(fd))
+ updateVectorTable()
+ return err
+}
diff --git a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go
new file mode 100644
index 000000000..6531bae1d
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go
@@ -0,0 +1,39 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "fmt"
+ "syscall"
+)
+
+var (
+ runDataSize int
+)
+
+func updateSystemValues(fd int) error {
+ // Extract the mmap size.
+ sz, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(fd), _KVM_GET_VCPU_MMAP_SIZE, 0)
+ if errno != 0 {
+ return fmt.Errorf("getting VCPU mmap size: %v", errno)
+ }
+ // Save the data.
+ runDataSize = int(sz)
+
+ // Success.
+ return nil
+}
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index d05f05c29..1d5c77ff4 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -49,11 +49,13 @@ const (
_KVM_EXIT_SHUTDOWN = 0x8
_KVM_EXIT_FAIL_ENTRY = 0x9
_KVM_EXIT_INTERNAL_ERROR = 0x11
+ _KVM_EXIT_SYSTEM_EVENT = 0x18
)
// KVM capability options.
const (
- _KVM_CAP_MAX_VCPUS = 0x42
+ _KVM_CAP_MAX_VCPUS = 0x42
+ _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5
)
// KVM limits.
@@ -62,3 +64,10 @@ const (
_KVM_NR_INTERRUPTS = 0x100
_KVM_NR_CPUID_ENTRIES = 0x100
)
+
+// KVM kvm_memory_region::flags.
+const (
+ _KVM_MEM_LOG_DIRTY_PAGES = uint32(1) << 0
+ _KVM_MEM_READONLY = uint32(1) << 1
+ _KVM_MEM_FLAGS_NONE = 0
+)
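
These flags flow from mapPhysical through handleBluepillFault into the userMemoryRegion handed to KVM_SET_USER_MEMORY_REGION. A hedged illustration of how a read-only slot would be described (slot, physical, length, and virtual are assumed to be in scope):

    // Illustrative only: describing a read-only region to KVM.
    region := userMemoryRegion{
        slot:          uint32(slot),
        flags:         _KVM_MEM_READONLY,
        guestPhysAddr: uint64(physical),
        memorySize:    uint64(length),
        userspaceAddr: uint64(virtual),
    }
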
diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go
new file mode 100644
index 000000000..5a74c6e36
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go
@@ -0,0 +1,132 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+// KVM ioctls for Arm64.
+const (
+ _KVM_GET_ONE_REG = 0x4010aeab
+ _KVM_SET_ONE_REG = 0x4010aeac
+
+ _KVM_ARM_PREFERRED_TARGET = 0x8020aeaf
+ _KVM_ARM_VCPU_INIT = 0x4020aeae
+ _KVM_ARM64_REGS_PSTATE = 0x6030000000100042
+ _KVM_ARM64_REGS_SP_EL1 = 0x6030000000100044
+ _KVM_ARM64_REGS_R0 = 0x6030000000100000
+ _KVM_ARM64_REGS_R1 = 0x6030000000100002
+ _KVM_ARM64_REGS_R2 = 0x6030000000100004
+ _KVM_ARM64_REGS_R3 = 0x6030000000100006
+ _KVM_ARM64_REGS_R8 = 0x6030000000100010
+ _KVM_ARM64_REGS_R18 = 0x6030000000100024
+ _KVM_ARM64_REGS_PC = 0x6030000000100040
+ _KVM_ARM64_REGS_MAIR_EL1 = 0x603000000013c510
+ _KVM_ARM64_REGS_TCR_EL1 = 0x603000000013c102
+ _KVM_ARM64_REGS_TTBR0_EL1 = 0x603000000013c100
+ _KVM_ARM64_REGS_TTBR1_EL1 = 0x603000000013c101
+ _KVM_ARM64_REGS_SCTLR_EL1 = 0x603000000013c080
+ _KVM_ARM64_REGS_CPACR_EL1 = 0x603000000013c082
+ _KVM_ARM64_REGS_VBAR_EL1 = 0x603000000013c600
+)
+
+// Arm64: Architectural Feature Access Control Register EL1.
+const (
+ _FPEN_NOTRAP = 0x3
+ _FPEN_SHIFT = 0x20
+)
+
+// Arm64: System Control Register EL1.
+const (
+ _SCTLR_M = 1 << 0
+ _SCTLR_C = 1 << 2
+ _SCTLR_I = 1 << 12
+)
+
+// Arm64: Translation Control Register EL1.
+const (
+ _TCR_IPS_40BITS = 2 << 32 // PA=40
+ _TCR_IPS_48BITS = 5 << 32 // PA=48
+
+ _TCR_T0SZ_OFFSET = 0
+ _TCR_T1SZ_OFFSET = 16
+ _TCR_IRGN0_SHIFT = 8
+ _TCR_IRGN1_SHIFT = 24
+ _TCR_ORGN0_SHIFT = 10
+ _TCR_ORGN1_SHIFT = 26
+ _TCR_SH0_SHIFT = 12
+ _TCR_SH1_SHIFT = 28
+ _TCR_TG0_SHIFT = 14
+ _TCR_TG1_SHIFT = 30
+
+ _TCR_T0SZ_VA48 = 64 - 48 // VA=48
+ _TCR_T1SZ_VA48 = 64 - 48 // VA=48
+
+ _TCR_ASID16 = 1 << 36
+ _TCR_TBI0 = 1 << 37
+
+ _TCR_TXSZ_VA48 = (_TCR_T0SZ_VA48 << _TCR_T0SZ_OFFSET) | (_TCR_T1SZ_VA48 << _TCR_T1SZ_OFFSET)
+
+ _TCR_TG0_4K = 0 << _TCR_TG0_SHIFT // 4K
+ _TCR_TG0_64K = 1 << _TCR_TG0_SHIFT // 64K
+
+ _TCR_TG1_4K = 2 << _TCR_TG1_SHIFT
+
+ _TCR_TG_FLAGS = _TCR_TG0_4K | _TCR_TG1_4K
+
+ _TCR_IRGN0_WBWA = 1 << _TCR_IRGN0_SHIFT
+ _TCR_IRGN1_WBWA = 1 << _TCR_IRGN1_SHIFT
+ _TCR_IRGN_WBWA = _TCR_IRGN0_WBWA | _TCR_IRGN1_WBWA
+
+ _TCR_ORGN0_WBWA = 1 << _TCR_ORGN0_SHIFT
+ _TCR_ORGN1_WBWA = 1 << _TCR_ORGN1_SHIFT
+
+ _TCR_ORGN_WBWA = _TCR_ORGN0_WBWA | _TCR_ORGN1_WBWA
+
+ _TCR_SHARED = (3 << _TCR_SH0_SHIFT) | (3 << _TCR_SH1_SHIFT)
+
+ _TCR_CACHE_FLAGS = _TCR_IRGN_WBWA | _TCR_ORGN_WBWA
+)
+
+// Arm64: Memory Attribute Indirection Register EL1.
+const (
+ _MT_DEVICE_nGnRnE = 0
+ _MT_DEVICE_nGnRE = 1
+ _MT_DEVICE_GRE = 2
+ _MT_NORMAL_NC = 3
+ _MT_NORMAL = 4
+ _MT_NORMAL_WT = 5
+ _MT_EL1_INIT = (0 << _MT_DEVICE_nGnRnE) | (0x4 << _MT_DEVICE_nGnRE * 8) | (0xc << _MT_DEVICE_GRE * 8) | (0x44 << _MT_NORMAL_NC * 8) | (0xff << _MT_NORMAL * 8) | (0xbb << _MT_NORMAL_WT * 8)
+)
+
+const (
+ _KVM_ARM_VCPU_POWER_OFF = 0 // CPU is started in OFF state
+ _KVM_ARM_VCPU_PSCI_0_2 = 2 // CPU uses PSCI v0.2
+)
+
+// Arm64: Exception Syndrome Register EL1.
+const (
+ _ESR_ELx_FSC = 0x3F
+
+ _ESR_SEGV_MAPERR_L0 = 0x4
+ _ESR_SEGV_MAPERR_L1 = 0x5
+ _ESR_SEGV_MAPERR_L2 = 0x6
+ _ESR_SEGV_MAPERR_L3 = 0x7
+
+ _ESR_SEGV_ACCERR_L1 = 0x9
+ _ESR_SEGV_ACCERR_L2 = 0xa
+ _ESR_SEGV_ACCERR_L3 = 0xb
+
+ _ESR_SEGV_PEMERR_L1 = 0xd
+ _ESR_SEGV_PEMERR_L2 = 0xe
+ _ESR_SEGV_PEMERR_L3 = 0xf
+)
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index cc6c138b2..7d02ebf19 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -215,6 +215,17 @@ func newMachine(vm int) (*machine, error) {
return true // Keep iterating.
})
+ var physicalRegionsReadOnly []physicalRegion
+ var physicalRegionsAvailable []physicalRegion
+
+ physicalRegionsReadOnly = rdonlyRegionsForSetMem()
+ physicalRegionsAvailable = availableRegionsForSetMem()
+
+ // Map all read-only regions.
+ for _, r := range physicalRegionsReadOnly {
+ m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY)
+ }
+
// Ensure that the currently mapped virtual regions are actually
// available in the VM. Note that this doesn't guarantee no future
// faults, however it should guarantee that everything is available to
@@ -223,6 +234,13 @@ func newMachine(vm int) (*machine, error) {
if excludeVirtualRegion(vr) {
return // skip region.
}
+
+ for _, r := range physicalRegionsReadOnly {
+ if vr.virtual == r.virtual {
+ return
+ }
+ }
+
for virtual := vr.virtual; virtual < vr.virtual+vr.length; {
physical, length, ok := translateToPhysical(virtual)
if !ok {
@@ -236,7 +254,7 @@ func newMachine(vm int) (*machine, error) {
}
// Ensure the physical range is mapped.
- m.mapPhysical(physical, length)
+ m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE)
virtual += length
}
})
@@ -256,9 +274,9 @@ func newMachine(vm int) (*machine, error) {
// not available. This attempts to be efficient for calls in the hot path.
//
// This panics on error.
-func (m *machine) mapPhysical(physical, length uintptr) {
+func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) {
for end := physical + length; physical < end; {
- _, physicalStart, length, ok := calculateBluepillFault(physical)
+ _, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions)
if !ok {
// Should never happen.
panic("mapPhysical on unknown physical address")
@@ -266,7 +284,7 @@ func (m *machine) mapPhysical(physical, length uintptr) {
if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok {
// Not present in the cache; requires setting the slot.
- if _, ok := handleBluepillFault(m, physical); !ok {
+ if _, ok := handleBluepillFault(m, physical, phyRegions, flags); !ok {
panic("handleBluepillFault failed")
}
}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index c1cbe33be..b99fe425e 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -355,3 +355,13 @@ func (m *machine) retryInGuest(fn func()) {
}
}
}
+
+// On the x86 platform, the flags for "setMemoryRegion" can always be set to 0,
+// so there is no need to return read-only physicalRegions.
+func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
+ return nil
+}
+
+func availableRegionsForSetMem() (phyRegions []physicalRegion) {
+ return physicalRegions
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
index 506ec9af1..7156c245f 100644
--- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -26,30 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/time"
)
-// setMemoryRegion initializes a region.
-//
-// This may be called from bluepillHandler, and therefore returns an errno
-// directly (instead of wrapping in an error) to avoid allocations.
-//
-//go:nosplit
-func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno {
- userRegion := userMemoryRegion{
- slot: uint32(slot),
- flags: 0,
- guestPhysAddr: uint64(physical),
- memorySize: uint64(length),
- userspaceAddr: uint64(virtual),
- }
-
- // Set the region.
- _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(m.fd),
- _KVM_SET_USER_MEMORY_REGION,
- uintptr(unsafe.Pointer(&userRegion)))
- return errno
-}
-
// loadSegments copies the current segments.
//
// This may be called from within the signal context and throws on error.
@@ -159,3 +135,43 @@ func (c *vCPU) setSignalMask() error {
}
return nil
}
+
+// setUserRegisters sets user registers in the vCPU.
+func (c *vCPU) setUserRegisters(uregs *userRegs) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_REGS,
+ uintptr(unsafe.Pointer(uregs))); errno != 0 {
+ return fmt.Errorf("error setting user registers: %v", errno)
+ }
+ return nil
+}
+
+// getUserRegisters reloads user registers in the vCPU.
+//
+// This is safe to call from a nosplit context.
+//
+//go:nosplit
+func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_GET_REGS,
+ uintptr(unsafe.Pointer(uregs))); errno != 0 {
+ return errno
+ }
+ return 0
+}
+
+// setSystemRegisters sets system registers.
+func (c *vCPU) setSystemRegisters(sregs *systemRegs) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SREGS,
+ uintptr(unsafe.Pointer(sregs))); errno != 0 {
+ return fmt.Errorf("error setting system registers: %v", errno)
+ }
+ return nil
+}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
new file mode 100644
index 000000000..7ae47f291
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -0,0 +1,183 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+type vCPUArchState struct {
+ // PCIDs is the set of PCIDs for this vCPU.
+ //
+ // This starts above fixedKernelPCID.
+ PCIDs *pagetables.PCIDs
+}
+
+const (
+ // fixedKernelPCID is a fixed kernel PCID used for the kernel page
+ // tables. We must start allocating user PCIDs above this in order to
+ // avoid any conflict (see below).
+ fixedKernelPCID = 1
+
+ // poolPCIDs is the number of PCIDs to record in the database. As this
+ // grows, assignment can take longer, since it is a simple linear scan.
+ // Beyond a relatively small number, there are likely few performance
+ // benefits, since the TLB has likely long since lost any translations
+ // from more than a few PCIDs past.
+ poolPCIDs = 8
+)
+
+// rdonlyRegionsForSetMem returns all read-only physicalRegions.
+func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
+ var rdonlyRegions []region
+
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ return
+ }
+
+ if !vr.accessType.Write && vr.accessType.Read {
+ rdonlyRegions = append(rdonlyRegions, vr.region)
+ }
+ })
+
+ for _, r := range rdonlyRegions {
+ physical, _, ok := translateToPhysical(r.virtual)
+ if !ok {
+ continue
+ }
+
+ phyRegions = append(phyRegions, physicalRegion{
+ region: region{
+ virtual: r.virtual,
+ length: r.length,
+ },
+ physical: physical,
+ })
+ }
+
+ return phyRegions
+}
+
+// availableRegionsForSetMem returns all available physicalRegions.
+func availableRegionsForSetMem() (phyRegions []physicalRegion) {
+ var excludeRegions []region
+ applyVirtualRegions(func(vr virtualRegion) {
+ if !vr.accessType.Write {
+ excludeRegions = append(excludeRegions, vr.region)
+ }
+ })
+
+ phyRegions = computePhysicalRegions(excludeRegions)
+
+ return phyRegions
+}
+
+// dropPageTables drops cached page table entries.
+func (m *machine) dropPageTables(pt *pagetables.PageTables) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ // Clear from all PCIDs.
+ for _, c := range m.vCPUs {
+ c.PCIDs.Drop(pt)
+ }
+}
+
+// nonCanonical generates the fault return for a non-canonical address.
+//
+//go:nosplit
+func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+ *info = arch.SignalInfo{
+ Signo: signal,
+ Code: arch.SignalInfoKernel,
+ }
+ info.SetAddr(addr) // Include address.
+ return usermem.NoAccess, platform.ErrContextSignal
+}
+
+// fault generates an appropriate fault return.
+//
+//go:nosplit
+func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+ faultAddr := c.GetFaultAddr()
+ code, user := c.ErrorCode()
+
+ // Reset the pointed SignalInfo.
+ *info = arch.SignalInfo{Signo: signal}
+ info.SetAddr(uint64(faultAddr))
+
+ read := true
+ write := false
+ execute := true
+
+ ret := code & _ESR_ELx_FSC
+ switch ret {
+ case _ESR_SEGV_MAPERR_L0, _ESR_SEGV_MAPERR_L1, _ESR_SEGV_MAPERR_L2, _ESR_SEGV_MAPERR_L3:
+ info.Code = 1 // SEGV_MAPERR.
+ read = false
+ write = true
+ execute = false
+ case _ESR_SEGV_ACCERR_L1, _ESR_SEGV_ACCERR_L2, _ESR_SEGV_ACCERR_L3, _ESR_SEGV_PEMERR_L1, _ESR_SEGV_PEMERR_L2, _ESR_SEGV_PEMERR_L3:
+ info.Code = 2 // SEGV_ACCERR.
+ read = true
+ write = false
+ execute = false
+ default:
+ info.Code = 2 // Default to SEGV_ACCERR.
+ }
+
+ if !user {
+ read = true
+ write = false
+ execute = true
+ }
+ accessType := usermem.AccessType{
+ Read: read,
+ Write: write,
+ Execute: execute,
+ }
+
+ return accessType, platform.ErrContextSignal
+}
+
+// retryInGuest runs the given function in guest mode.
+//
+// If the function does not complete in guest mode (due to execution of a
+// system call due to a GC stall, for example), then it will be retried. The
+// given function must be idempotent as a result of the retry mechanism.
+func (m *machine) retryInGuest(fn func()) {
+ c := m.Get()
+ defer m.Put(c)
+ for {
+ c.ClearErrorCode() // See below.
+ bluepill(c) // Force guest mode.
+ fn() // Execute the given function.
+ _, user := c.ErrorCode()
+ if user {
+ // If user is set, then we haven't bailed back to host
+ // mode via a kernel exception or system call. We
+ // consider the full function to have executed in guest
+ // mode and we can return.
+ break
+ }
+ }
+}
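
A hedged usage sketch: because a GC stall or host syscall can bounce execution back to host mode mid-callback, fn may run more than once and must therefore be idempotent:

    // Illustrative only: fn must have no one-shot side effects.
    m.retryInGuest(func() {
        // e.g. sample a guest-visible register or counter; safe to repeat.
    })
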
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
new file mode 100644
index 000000000..3f2f97a6b
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -0,0 +1,362 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "fmt"
+ "reflect"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+// setMemoryRegion initializes a region.
+//
+// This may be called from bluepillHandler, and therefore returns an errno
+// directly (instead of wrapping in an error) to avoid allocations.
+//
+//go:nosplit
+func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr, flags uint32) syscall.Errno {
+ userRegion := userMemoryRegion{
+ slot: uint32(slot),
+ flags: flags,
+ guestPhysAddr: uint64(physical),
+ memorySize: uint64(length),
+ userspaceAddr: uint64(virtual),
+ }
+
+ // Set the region.
+ _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(m.fd),
+ _KVM_SET_USER_MEMORY_REGION,
+ uintptr(unsafe.Pointer(&userRegion)))
+ return errno
+}
+
+type kvmVcpuInit struct {
+ target uint32
+ features [7]uint32
+}
+
+var vcpuInit kvmVcpuInit
+
+// initArchState initializes architecture-specific state.
+func (m *machine) initArchState() error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(m.fd),
+ _KVM_ARM_PREFERRED_TARGET,
+ uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 {
+ panic(fmt.Sprintf("error setting KVM_ARM_PREFERRED_TARGET failed: %v", errno))
+ }
+ return nil
+}
+
+func getPageWithReflect(p uintptr) []byte {
+ return (*(*[0xFFFFFF]byte)(unsafe.Pointer(p & ^uintptr(syscall.Getpagesize()-1))))[:syscall.Getpagesize()]
+}
+
+// Workaround: move ring0.Vectors() to an address with 11-bit alignment.
+//
+// According to the Arm64 architecture documentation, the start address of the
+// exception vector table must be 11-bit aligned; see the Linux kernel's
+// arch/arm64/kernel/entry.S for reference. However, Go provides no way to
+// align a function's start address to a specific boundary. We have raised
+// this question with the Go community:
+// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I
+// This function will be removed once Go supports this feature.
+//
+// This function does two jobs:
+// 1. moves the start address of the exception vector table to the aligned address;
+// 2. fixes up the offset encoded in each branch instruction.
+func updateVectorTable() {
+ fromLocation := reflect.ValueOf(ring0.Vectors).Pointer()
+ offset := fromLocation & (1<<11 - 1)
+ if offset != 0 {
+ offset = 1<<11 - offset
+ }
+
+ toLocation := fromLocation + offset
+ page := getPageWithReflect(toLocation)
+ if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil {
+ panic(err)
+ }
+
+ page = getPageWithReflect(toLocation + 4096)
+ if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC); err != nil {
+ panic(err)
+ }
+
+ // Move exception-vector-table into the specific address.
+ var entry *uint32
+ var entryFrom *uint32
+ for i := 1; i <= 0x800; i++ {
+ entry = (*uint32)(unsafe.Pointer(toLocation + 0x800 - uintptr(i)))
+ entryFrom = (*uint32)(unsafe.Pointer(fromLocation + 0x800 - uintptr(i)))
+ *entry = *entryFrom
+ }
+
+ // Moving the table changes the target offset of each unconditional branch,
+ // so patch the offset encoded in each of those instructions.
+ nums := []uint32{0x0, 0x80, 0x100, 0x180, 0x200, 0x280, 0x300, 0x380, 0x400, 0x480, 0x500, 0x580, 0x600, 0x680, 0x700, 0x780}
+ for _, num := range nums {
+ entry = (*uint32)(unsafe.Pointer(toLocation + uintptr(num)))
+ *entry = *entry - (uint32)(offset/4)
+ }
+
+ page = getPageWithReflect(toLocation)
+ if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil {
+ panic(err)
+ }
+
+ page = getPageWithReflect(toLocation + 4096)
+ if err := syscall.Mprotect(page, syscall.PROT_READ|syscall.PROT_EXEC); err != nil {
+ panic(err)
+ }
+}
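
The offset arithmetic above rounds fromLocation up to the next 2^11-byte (2 KiB) boundary required for VBAR_EL1; the same computation reappears in initArchState below. A standalone restatement of the identity, assuming nothing beyond the arithmetic itself:

    // toLocation is the lowest 2KiB-aligned address >= fromLocation.
    offset := fromLocation & (1<<11 - 1)
    if offset != 0 {
        offset = 1<<11 - offset
    }
    toLocation := fromLocation + offset
    // Invariant: toLocation%(1<<11) == 0 && toLocation-fromLocation < 1<<11.
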
+
+// initArchState initializes architecture-specific state.
+func (c *vCPU) initArchState() error {
+ var (
+ reg kvmOneReg
+ data uint64
+ regGet kvmOneReg
+ dataGet uint64
+ )
+
+ reg.addr = uint64(reflect.ValueOf(&data).Pointer())
+ regGet.addr = uint64(reflect.ValueOf(&dataGet).Pointer())
+
+ vcpuInit.features[0] |= (1 << _KVM_ARM_VCPU_PSCI_0_2)
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_ARM_VCPU_INIT,
+ uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 {
+ panic(fmt.Sprintf("error setting KVM_ARM_VCPU_INIT failed: %v", errno))
+ }
+
+ // cpacr_el1
+ reg.id = _KVM_ARM64_REGS_CPACR_EL1
+ data = (_FPEN_NOTRAP << _FPEN_SHIFT)
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // sctlr_el1
+ regGet.id = _KVM_ARM64_REGS_SCTLR_EL1
+ if err := c.getOneRegister(&regGet); err != nil {
+ return err
+ }
+
+ dataGet |= (_SCTLR_M | _SCTLR_C | _SCTLR_I)
+ data = dataGet
+ reg.id = _KVM_ARM64_REGS_SCTLR_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // tcr_el1
+ data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS
+ reg.id = _KVM_ARM64_REGS_TCR_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // mair_el1
+ data = _MT_EL1_INIT
+ reg.id = _KVM_ARM64_REGS_MAIR_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // ttbr0_el1
+ data = c.machine.kernel.PageTables.TTBR0_EL1(false, 0)
+
+ reg.id = _KVM_ARM64_REGS_TTBR0_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ c.SetTtbr0Kvm(uintptr(data))
+
+ // ttbr1_el1
+ data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0)
+
+ reg.id = _KVM_ARM64_REGS_TTBR1_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // sp_el1
+ data = c.CPU.StackTop()
+ reg.id = _KVM_ARM64_REGS_SP_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // pc
+ reg.id = _KVM_ARM64_REGS_PC
+ data = uint64(reflect.ValueOf(ring0.Start).Pointer())
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // r8
+ reg.id = _KVM_ARM64_REGS_R8
+ data = uint64(reflect.ValueOf(&c.CPU).Pointer())
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // vbar_el1
+ reg.id = _KVM_ARM64_REGS_VBAR_EL1
+
+ fromLocation := reflect.ValueOf(ring0.Vectors).Pointer()
+ offset := fromLocation & (1<<11 - 1)
+ if offset != 0 {
+ offset = 1<<11 - offset
+ }
+
+ toLocation := fromLocation + offset
+ data = uint64(ring0.KernelStartAddress | toLocation)
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ data = ring0.PsrDefaultSet | ring0.KernelFlagsSet
+ reg.id = _KVM_ARM64_REGS_PSTATE
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+//go:nosplit
+func (c *vCPU) loadSegments(tid uint64) {
+ // TODO(gvisor.dev/issue/1238): TLS is not supported.
+ // Get TLS from tpidr_el0.
+ atomic.StoreUint64(&c.tid, tid)
+}
+
+func (c *vCPU) setOneRegister(reg *kvmOneReg) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_ONE_REG,
+ uintptr(unsafe.Pointer(reg))); errno != 0 {
+ return fmt.Errorf("error setting one register: %v", errno)
+ }
+ return nil
+}
+
+func (c *vCPU) getOneRegister(reg *kvmOneReg) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_GET_ONE_REG,
+ uintptr(unsafe.Pointer(reg))); errno != 0 {
+ return fmt.Errorf("error getting one register: %v", errno)
+ }
+ return nil
+}
+
+// setCPUID sets the CPUID to be used by the guest (a no-op on arm64).
+func (c *vCPU) setCPUID() error {
+ return nil
+}
+
+// setSystemTime sets the vCPU system time (a no-op on arm64, which has no TSC).
+func (c *vCPU) setSystemTime() error {
+ return nil
+}
+
+// setSignalMask sets the vCPU signal mask.
+//
+// This must be called prior to running the vCPU.
+func (c *vCPU) setSignalMask() error {
+ // The layout of this structure must match what the kernel expects; it
+ // will not necessarily match the layout the Go compiler would choose
+ // for an equivalent type, so it is laid out explicitly here.
+ var data struct {
+ length uint32
+ mask1 uint32
+ mask2 uint32
+ _ uint32
+ }
+ data.length = 8 // Fixed sigset size.
+ data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
+ data.mask2 = ^uint32(bounceSignalMask >> 32)
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SIGNAL_MASK,
+ uintptr(unsafe.Pointer(&data))); errno != 0 {
+ return fmt.Errorf("error setting signal mask: %v", errno)
+ }
+
+ return nil
+}
+
+// SwitchToUser unpacks architectural details and switches to user mode.
+func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) {
+ // Check for canonical addresses.
+ if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) {
+ return nonCanonical(regs.Pc, int32(syscall.SIGSEGV), info)
+ } else if !ring0.IsCanonical(regs.Sp) {
+ return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info)
+ }
+
+ var vector ring0.Vector
+ ttbr0App := switchOpts.PageTables.TTBR0_EL1(false, 0)
+ c.SetTtbr0App(uintptr(ttbr0App))
+
+ // TODO(gvisor.dev/issue/1238): full context-switch support for Arm64.
+ // The Arm64 user-mode execution state consists of:
+ // x0-x30
+ // PC, SP, PSTATE
+ // V0-V31: 32 128-bit registers for floating point, and simd
+ // FPSR
+ // TPIDR_EL0, used for TLS
+ appRegs := switchOpts.Registers
+ c.SetAppAddr(ring0.KernelStartAddress | uintptr(unsafe.Pointer(appRegs)))
+
+ entersyscall()
+ bluepill(c)
+ vector = c.CPU.SwitchToUser(switchOpts)
+ exitsyscall()
+
+ switch vector {
+ case ring0.Syscall:
+ // Fast path: system call executed.
+ return usermem.NoAccess, nil
+
+ case ring0.PageFault:
+ return c.fault(int32(syscall.SIGSEGV), info)
+ case 0xaa:
+ return usermem.NoAccess, nil
+ default:
+ return usermem.NoAccess, platform.ErrContextSignal
+ }
+}
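For readers tracing the alignment arithmetic in updateVectorTable above, here is a standalone sketch (illustrative only, not part of the patch) of the round-up-to-2KB computation and the resulting branch-immediate correction:

```go
package main

import "fmt"

// alignUp2KB rounds addr up to the next 1<<11 (2KB) boundary, mirroring the
// offset computation in updateVectorTable.
func alignUp2KB(addr uintptr) uintptr {
	offset := addr & (1<<11 - 1)
	if offset != 0 {
		offset = 1<<11 - offset
	}
	return addr + offset
}

func main() {
	from := uintptr(0x40a123)
	to := alignUp2KB(from)
	fmt.Printf("from=%#x to=%#x moved=%#x bytes\n", from, to, to-from)
	// AArch64 instructions are 4 bytes wide, so a PC-relative branch moved
	// forward by (to-from) bytes needs its immediate reduced by
	// (to-from)/4 instructions, i.e. the "*entry - offset/4" in the patch loop.
	fmt.Printf("branch immediate delta: %d instructions\n", (to-from)/4)
}
```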
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index 405e00292..f04be2ab5 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.12
-// +build !go1.14
+// +build !go1.15
// Check go:linkname function signatures when updating Go version.
@@ -35,6 +35,30 @@ func entersyscall()
//go:linkname exitsyscall runtime.exitsyscall
func exitsyscall()
+// setMemoryRegion initializes a region.
+//
+// This may be called from bluepillHandler, and therefore returns an errno
+// directly (instead of wrapping in an error) to avoid allocations.
+//
+//go:nosplit
+func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr, flags uint32) syscall.Errno {
+ userRegion := userMemoryRegion{
+ slot: uint32(slot),
+ flags: uint32(flags),
+ guestPhysAddr: uint64(physical),
+ memorySize: uint64(length),
+ userspaceAddr: uint64(virtual),
+ }
+
+ // Set the region.
+ _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(m.fd),
+ _KVM_SET_USER_MEMORY_REGION,
+ uintptr(unsafe.Pointer(&userRegion)))
+ return errno
+}
+
// mapRunData maps the vCPU run data.
func mapRunData(fd int) (*runData, error) {
r, _, errno := syscall.RawSyscall6(
@@ -63,46 +87,6 @@ func unmapRunData(r *runData) error {
return nil
}
-// setUserRegisters sets user registers in the vCPU.
-func (c *vCPU) setUserRegisters(uregs *userRegs) error {
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_SET_REGS,
- uintptr(unsafe.Pointer(uregs))); errno != 0 {
- return fmt.Errorf("error setting user registers: %v", errno)
- }
- return nil
-}
-
-// getUserRegisters reloads user registers in the vCPU.
-//
-// This is safe to call from a nosplit context.
-//
-//go:nosplit
-func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno {
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_GET_REGS,
- uintptr(unsafe.Pointer(uregs))); errno != 0 {
- return errno
- }
- return 0
-}
-
-// setSystemRegisters sets system registers.
-func (c *vCPU) setSystemRegisters(sregs *systemRegs) error {
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_SET_SREGS,
- uintptr(unsafe.Pointer(sregs))); errno != 0 {
- return fmt.Errorf("error setting system registers: %v", errno)
- }
- return nil
-}
-
// atomicAddressSpace is an atomic address space pointer.
type atomicAddressSpace struct {
pointer unsafe.Pointer
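The userMemoryRegion type used by setMemoryRegion above is defined elsewhere in this package; for context, it must mirror struct kvm_userspace_memory_region from the KVM UAPI (linux/kvm.h), since its address is handed directly to ioctl(2). A sketch of the expected layout:

```go
// Field order and widths must match the kernel ABI exactly; any padding
// difference would corrupt the KVM_SET_USER_MEMORY_REGION argument.
type userMemoryRegion struct {
	slot          uint32
	flags         uint32
	guestPhysAddr uint64
	memorySize    uint64
	userspaceAddr uint64
}
```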
diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go
index 586e91bb2..91de5dab1 100644
--- a/pkg/sentry/platform/kvm/physical_map.go
+++ b/pkg/sentry/platform/kvm/physical_map.go
@@ -24,15 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
-const (
- // reservedMemory is a chunk of physical memory reserved starting at
- // physical address zero. There are some special pages in this region,
- // so we just call the whole thing off.
- //
- // Other architectures may define this to be zero.
- reservedMemory = 0x100000000
-)
-
type region struct {
virtual uintptr
length uintptr
@@ -59,8 +50,7 @@ func fillAddressSpace() (excludedRegions []region) {
// We can cut vSize in half, because the kernel will be using the top
// half and we ignore it while constructing mappings. It's as if we've
// already excluded half the possible addresses.
- vSize := uintptr(1) << ring0.VirtualAddressBits()
- vSize = vSize >> 1
+ vSize := ring0.UserspaceSize
// We exclude reservedMemory below from our physical memory size, so it
// needs to be dropped here as well. Otherwise, we could end up with
diff --git a/pkg/sentry/platform/kvm/physical_map_amd64.go b/pkg/sentry/platform/kvm/physical_map_amd64.go
new file mode 100644
index 000000000..c5adfb577
--- /dev/null
+++ b/pkg/sentry/platform/kvm/physical_map_amd64.go
@@ -0,0 +1,22 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+const (
+ // reservedMemory is a chunk of physical memory reserved starting at
+ // physical address zero. There are some special pages in this region,
+ // so we just call the whole thing off.
+ reservedMemory = 0x100000000
+)
diff --git a/pkg/sentry/platform/kvm/physical_map_arm64.go b/pkg/sentry/platform/kvm/physical_map_arm64.go
new file mode 100644
index 000000000..4d8561453
--- /dev/null
+++ b/pkg/sentry/platform/kvm/physical_map_arm64.go
@@ -0,0 +1,19 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+const (
+ reservedMemory = 0
+)
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
index 2cd28b2d2..0bebee852 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
@@ -50,6 +50,21 @@ TEXT ·SpinLoop(SB),NOSPLIT,$0
start:
B start
+TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8
+ NO_LOCAL_POINTERS
+ FMOVD $(9.9), F0
+ MOVD $SYS_GETPID, R8 // getpid
+ SVC
+ FMOVD $(9.9), F1
+ FCMPD F0, F1
+ BNE isNaN
+ MOVD $1, R0
+ MOVD R0, ret+0(FP)
+ RET
+isNaN:
+ MOVD $0, ret+0(FP)
+ RET
+
// MVN: bitwise logical NOT
// This case simulates an application that modified R0-R30.
#define TWIDDLE_REGS() \
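A hedged usage sketch for the new FloatingPointWorks helper (assuming the Go-side declaration func FloatingPointWorks() bool, as on amd64): the assembly loads 9.9 into F0, issues a getpid SVC, and reports whether F0 survived the trip through the kernel.

```go
package kvm_test

import (
	"testing"

	"gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
)

func TestFloatingPointPreserved(t *testing.T) {
	// If FP/SIMD state is clobbered across a system call, the assembly
	// takes the isNaN branch and returns false.
	if !testutil.FloatingPointWorks() {
		t.Fatal("FP registers were not preserved across SVC")
	}
}
```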
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index ebcc8c098..0df8cfa0f 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -28,6 +28,7 @@ go_library(
"//pkg/procid",
"//pkg/seccomp",
"//pkg/sentry/arch",
+ "//pkg/sentry/hostcpu",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
"//pkg/sentry/platform/safecopy",
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 9f0ecfbe4..ddb1f41e3 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -327,6 +327,20 @@ func (t *thread) dumpAndPanic(message string) {
panic(message)
}
+func (t *thread) unexpectedStubExit() {
+ msg, err := t.getEventMessage()
+ status := syscall.WaitStatus(msg)
+ if status.Signaled() && status.Signal() == syscall.SIGKILL {
+ // SIGKILL can only be sent by a user or by the OOM-killer. In
+ // both cases, we don't need to panic: there is no reason to
+ // think that something is wrong in gVisor.
+ log.Warningf("The ptrace stub process %v has been killed by SIGKILL.", t.tgid)
+ pid := os.Getpid()
+ syscall.Tgkill(pid, pid, syscall.Signal(syscall.SIGKILL))
+ }
+ t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err))
+}
+
// wait waits for a stop event.
//
// Precondition: outcome is a valid waitOutcome.
@@ -355,8 +369,7 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal {
}
if stopSig == syscall.SIGTRAP {
if status.TrapCause() == syscall.PTRACE_EVENT_EXIT {
- msg, err := t.getEventMessage()
- t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err))
+ t.unexpectedStubExit()
}
// Re-encode the trap cause the way it's expected.
return stopSig | syscall.Signal(status.TrapCause()<<8)
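The SIGKILL detection in unexpectedStubExit hinges on decoding the PTRACE_EVENT_EXIT event message as a wait status. A minimal self-contained illustration of that decoding:

```go
package main

import (
	"fmt"
	"syscall"
)

func main() {
	// For a PTRACE_EVENT_EXIT stop, PTRACE_GETEVENTMSG returns the
	// would-be exit status. A process killed by a signal encodes that
	// signal in the low bits of the status word.
	status := syscall.WaitStatus(uint32(syscall.SIGKILL))
	if status.Signaled() && status.Signal() == syscall.SIGKILL {
		fmt.Println("stub was SIGKILLed (user or OOM-killer): exit, don't panic")
	}
}
```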
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index 4649a94a7..606dc2b1d 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -21,6 +21,8 @@ import (
"strings"
"syscall"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
)
@@ -143,3 +145,49 @@ func (t *thread) adjustInitRegsRip() {
func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
initregs.R15 = uint64(ppid)
}
+
+// patchSignalInfo patches the signal info to account for hitting the seccomp
+// filters from vsyscall emulation, specified below. We allow for SIGSYS as a
+// synchronous trap, but patch the structure to appear like a SIGSEGV with the
+// Rip as the faulting address.
+//
+// Note that this should only be called after verifying that the signalInfo has
+// been generated by the kernel.
+func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) {
+ if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
+ signalInfo.Signo = int32(linux.SIGSEGV)
+
+ // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered
+ // with the si_call_addr field pointing to the current RIP. This field
+ // aligns with the si_addr field for a SIGSEGV, so we don't need to touch
+ // anything there. We do need to unwind emulation however, so we set the
+ // instruction pointer to the faulting value, and "unpop" the stack.
+ regs.Rip = signalInfo.Addr()
+ regs.Rsp -= 8
+ }
+}
+
+// enableCpuidFault enables cpuid-faulting.
+//
+// This may fail on older kernels or hardware, so we just disregard the result.
+// Host CPUID will be enabled.
+//
+// This is safe to call in an afterFork context.
+//
+//go:nosplit
+func enableCpuidFault() {
+ syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0)
+}
+
+// appendArchSeccompRules appends architecture-specific seccomp rules used when
+// creating the BPF program. See attachedThread() for more detail.
+func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet {
+ return append(rules, seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
+ {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ })
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
index bec884ba5..62a686ee7 100644
--- a/pkg/sentry/platform/ptrace/subprocess_arm64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -17,8 +17,12 @@
package ptrace
import (
+ "fmt"
+ "strings"
"syscall"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
)
@@ -37,7 +41,7 @@ const (
// resetSysemuRegs sets up emulation registers.
//
// This should be called prior to calling sysemu.
-func (s *subprocess) resetSysemuRegs(regs *syscall.PtraceRegs) {
+func (t *thread) resetSysemuRegs(regs *syscall.PtraceRegs) {
}
// createSyscallRegs sets up syscall registers.
@@ -124,3 +128,36 @@ func (t *thread) adjustInitRegsRip() {
func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
initregs.Regs[7] = uint64(ppid)
}
+
+// patchSignalInfo patches the signal info to account for hitting the seccomp
+// filters from vsyscall emulation, specified below. We allow for SIGSYS as a
+// synchronous trap, but patch the structure to appear like a SIGSEGV with the
+// Rip as the faulting address.
+//
+// Note that this should only be called after verifying that the signalInfo has
+// been generated by the kernel.
+func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) {
+ if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
+ signalInfo.Signo = int32(linux.SIGSEGV)
+
+ // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered
+ // with the si_call_addr field pointing to the current RIP. This field
+ // aligns with the si_addr field for a SIGSEGV, so we don't need to touch
+ // anything there. We do need to unwind emulation however, so we set the
+ // instruction pointer to the faulting value, and "unpop" the stack.
+ regs.Pc = signalInfo.Addr()
+ regs.Sp -= 8
+ }
+}
+
+// Noop on arm64.
+//
+//go:nosplit
+func enableCpuidFault() {
+}
+
+// appendArchSeccompRules appends architecture-specific seccomp rules used when
+// creating the BPF program. See attachedThread() for more detail.
+func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet {
+ return rules
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index c075b5f91..cf13ea5e4 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -20,6 +20,7 @@ import (
"fmt"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
@@ -77,27 +78,6 @@ func probeSeccomp() bool {
}
}
-// patchSignalInfo patches the signal info to account for hitting the seccomp
-// filters from vsyscall emulation, specified below. We allow for SIGSYS as a
-// synchronous trap, but patch the structure to appear like a SIGSEGV with the
-// Rip as the faulting address.
-//
-// Note that this should only be called after verifying that the signalInfo has
-// been generated by the kernel.
-func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) {
- if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
- signalInfo.Signo = int32(linux.SIGSEGV)
-
- // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered
- // with the si_call_addr field pointing to the current RIP. This field
- // aligns with the si_addr field for a SIGSEGV, so we don't need to touch
- // anything there. We do need to unwind emulation however, so we set the
- // instruction pointer to the faulting value, and "unpop" the stack.
- regs.Rip = signalInfo.Addr()
- regs.Rsp -= 8
- }
-}
-
// createStub creates a fresh stub processes.
//
// Precondition: the runtime OS thread must be locked.
@@ -129,6 +109,9 @@ func createStub() (*thread, error) {
// transitively) will be killed as well. It's simply not possible to
// safely handle a single stub getting killed: the exact state of
// execution is unknown and not recoverable.
+ //
+ // In addition, we set the PTRACE_O_TRACEEXIT option to log more
+ // information about a stub process when it receives a fatal signal.
return attachedThread(uintptr(syscall.SIGKILL)|syscall.CLONE_FILES, defaultAction)
}
@@ -146,7 +129,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
Rules: seccomp.SyscallRules{
syscall.SYS_GETTIMEOFDAY: {},
syscall.SYS_TIME: {},
- 309: {}, // SYS_GETCPU.
+ unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64.
},
Action: linux.SECCOMP_RET_TRAP,
Vsyscall: true,
@@ -170,10 +153,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// For the initial process creation.
syscall.SYS_WAIT4: {},
- syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
- {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
- },
- syscall.SYS_EXIT: {},
+ syscall.SYS_EXIT: {},
// For the stub prctl dance (all).
syscall.SYS_PRCTL: []seccomp.Rule{
@@ -193,6 +173,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
},
Action: linux.SECCOMP_RET_ALLOW,
})
+
+ rules = appendArchSeccompRules(rules)
}
instrs, err := seccomp.BuildProgram(rules, defaultAction)
if err != nil {
@@ -264,9 +246,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
syscall.RawSyscall(syscall.SYS_EXIT, uintptr(errno), 0, 0)
}
- // Enable cpuid-faulting; this may fail on older kernels or hardware,
- // so we just disregard the result. Host CPUID will be enabled.
- syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0)
+ // Enable cpuid-faulting.
+ enableCpuidFault()
// Call the stub; should not return.
stubCall(stubStart, ppid)
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
index de6783fb0..2e6fbe488 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
@@ -25,6 +25,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/hostcpu"
)
// maskPool contains reusable CPU masks for setting affinity. Unfortunately,
@@ -49,20 +50,6 @@ func unmaskAllSignals() syscall.Errno {
return errno
}
-// getCPU gets the current CPU.
-//
-// Precondition: the current runtime thread should be locked.
-func getCPU() (uint32, error) {
- var cpu uintptr
- if _, _, errno := syscall.RawSyscall(
- unix.SYS_GETCPU,
- uintptr(unsafe.Pointer(&cpu)),
- 0, 0); errno != 0 {
- return 0, errno
- }
- return uint32(cpu), nil
-}
-
// setCPU sets the CPU affinity.
func (t *thread) setCPU(cpu uint32) error {
mask := maskPool.Get().([]uintptr)
@@ -93,10 +80,8 @@ func (t *thread) setCPU(cpu uint32) error {
//
// Precondition: the current runtime thread should be locked.
func (t *thread) bind() {
- currentCPU, err := getCPU()
- if err != nil {
- return
- }
+ currentCPU := hostcpu.GetCPU()
+
if oldCPU := atomic.SwapUint32(&t.cpu, currentCPU); oldCPU != currentCPU {
// Set the affinity on the thread and save the CPU for next
// round; we don't expect CPUs to bounce around too frequently.
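The rewritten bind() above caches the last-seen CPU and only updates affinity when it changes. The same pattern in isolation (a sketch; hostcpu.GetCPU is only meaningful while the goroutine is pinned to one OS thread):

```go
package main

import (
	"fmt"
	"runtime"
	"sync/atomic"

	"gvisor.dev/gvisor/pkg/sentry/hostcpu"
)

func main() {
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	var cached uint32 // zero value: no CPU recorded yet
	current := hostcpu.GetCPU()
	if old := atomic.SwapUint32(&cached, current); old != current {
		fmt.Printf("CPU changed %d -> %d; affinity would be updated here\n", old, current)
	}
}
```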
diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
index b80a3604d..2ae6b9f9d 100644
--- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go
+++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.12
-// +build !go1.14
+// +build !go1.15
// Check go:linkname function signatures when updating Go version.
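For context on why these files carry paired build constraints: they reach unexported runtime internals via go:linkname, so the upper bound (!go1.15) forces a manual signature audit on every toolchain upgrade. The pattern, condensed into a minimal sketch:

```go
// +build go1.12
// +build !go1.15

// Check go:linkname function signatures when updating Go version.

package ptrace

import (
	_ "unsafe" // required for go:linkname
)

//go:linkname entersyscall runtime.entersyscall
func entersyscall()

//go:linkname exitsyscall runtime.exitsyscall
func exitsyscall()
```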
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
index 48b0ceaec..87f4552b5 100644
--- a/pkg/sentry/platform/ring0/BUILD
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -4,7 +4,7 @@ load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
go_template(
- name = "defs",
+ name = "defs_amd64",
srcs = [
"defs.go",
"defs_amd64.go",
@@ -14,11 +14,29 @@ go_template(
visibility = [":__subpackages__"],
)
+go_template(
+ name = "defs_arm64",
+ srcs = [
+ "aarch64.go",
+ "defs.go",
+ "defs_arm64.go",
+ "offsets_arm64.go",
+ ],
+ visibility = [":__subpackages__"],
+)
+
go_template_instance(
- name = "defs_impl",
- out = "defs_impl.go",
+ name = "defs_impl_amd64",
+ out = "defs_impl_amd64.go",
package = "ring0",
- template = ":defs",
+ template = ":defs_amd64",
+)
+
+go_template_instance(
+ name = "defs_impl_arm64",
+ out = "defs_impl_arm64.go",
+ package = "ring0",
+ template = ":defs_arm64",
)
genrule(
@@ -29,17 +47,31 @@ genrule(
tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
)
+genrule(
+ name = "entry_impl_arm64",
+ srcs = ["entry_arm64.s"],
+ outs = ["entry_impl_arm64.s"],
+ cmd = "(echo -e '// +build arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
+)
+
go_library(
name = "ring0",
srcs = [
- "defs_impl.go",
+ "defs_impl_amd64.go",
+ "defs_impl_arm64.go",
"entry_amd64.go",
+ "entry_arm64.go",
"entry_impl_amd64.s",
+ "entry_impl_arm64.s",
"kernel.go",
"kernel_amd64.go",
+ "kernel_arm64.go",
"kernel_unsafe.go",
"lib_amd64.go",
"lib_amd64.s",
+ "lib_arm64.go",
+ "lib_arm64.s",
"ring0.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0",
diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go
new file mode 100644
index 000000000..6b078cd1e
--- /dev/null
+++ b/pkg/sentry/platform/ring0/aarch64.go
@@ -0,0 +1,109 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+// Useful bits.
+const (
+ _PGD_PGT_BASE = 0x1000
+ _PGD_PGT_SIZE = 0x1000
+ _PUD_PGT_BASE = 0x2000
+ _PUD_PGT_SIZE = 0x1000
+ _PMD_PGT_BASE = 0x3000
+ _PMD_PGT_SIZE = 0x4000
+ _PTE_PGT_BASE = 0x7000
+ _PTE_PGT_SIZE = 0x1000
+
+ _PSR_MODE_EL0t = 0x0
+ _PSR_MODE_EL1t = 0x4
+ _PSR_MODE_EL1h = 0x5
+ _PSR_EL_MASK = 0xf
+
+ _PSR_D_BIT = 0x200
+ _PSR_A_BIT = 0x100
+ _PSR_I_BIT = 0x80
+ _PSR_F_BIT = 0x40
+)
+
+const (
+ // KernelFlagsSet should always be set in the kernel.
+ KernelFlagsSet = _PSR_MODE_EL1h
+
+ // UserFlagsSet are always set in userspace.
+ UserFlagsSet = _PSR_MODE_EL0t
+
+ KernelFlagsClear = _PSR_EL_MASK
+ UserFlagsClear = _PSR_EL_MASK
+
+ PsrDefaultSet = _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT
+)
+
+// Vector is an exception vector.
+type Vector uintptr
+
+// Exception vectors.
+const (
+ El1SyncInvalid = iota
+ El1IrqInvalid
+ El1FiqInvalid
+ El1ErrorInvalid
+ El1Sync
+ El1Irq
+ El1Fiq
+ El1Error
+ El0Sync
+ El0Irq
+ El0Fiq
+ El0Error
+ El0Sync_invalid
+ El0Irq_invalid
+ El0Fiq_invalid
+ El0Error_invalid
+ El1Sync_da
+ El1Sync_ia
+ El1Sync_sp_pc
+ El1Sync_undef
+ El1Sync_dbg
+ El1Sync_inv
+ El0Sync_svc
+ El0Sync_da
+ El0Sync_ia
+ El0Sync_fpsimd_acc
+ El0Sync_sve_acc
+ El0Sync_sys
+ El0Sync_sp_pc
+ El0Sync_undef
+ El0Sync_dbg
+ El0Sync_inv
+ VirtualizationException
+ _NR_INTERRUPTS
+)
+
+// System call vectors.
+const (
+ Syscall Vector = El0Sync_svc
+ PageFault Vector = El0Sync_da
+)
+
+// VirtualAddressBits returns the number of bits available for virtual addresses.
+func VirtualAddressBits() uint32 {
+ return 48
+}
+
+// PhysicalAddressBits returns the number of bits available for physical addresses.
+func PhysicalAddressBits() uint32 {
+ return 40
+}
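With the 48 virtual address bits reported above, the per-architecture constants this change introduces (see the defs_amd64.go and defs_arm64.go hunks below) can be checked by hand; a quick sketch of the amd64 arithmetic:

```go
package main

import "fmt"

func main() {
	const virtualAddressBits = 48
	const pageSize = 4096

	// amd64 halves the range: user and kernel split one address space.
	userspaceSize := uintptr(1) << (virtualAddressBits - 1)
	maximumUserAddress := (userspaceSize - 1) & ^uintptr(pageSize-1)
	kernelStartAddress := ^uintptr(0) - (userspaceSize - 1)

	fmt.Printf("UserspaceSize      = %#x (128 TiB)\n", userspaceSize) // 0x800000000000
	fmt.Printf("MaximumUserAddress = %#x\n", maximumUserAddress)      // 0x7ffffffff000
	fmt.Printf("KernelStartAddress = %#x\n", kernelStartAddress)      // 0xffff800000000000
}
```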
diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go
index 076063f85..3f094c2a7 100644
--- a/pkg/sentry/platform/ring0/defs.go
+++ b/pkg/sentry/platform/ring0/defs.go
@@ -20,17 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
-var (
- // UserspaceSize is the total size of userspace.
- UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1)
-
- // MaximumUserAddress is the largest possible user address.
- MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1)
-
- // KernelStartAddress is the starting kernel address.
- KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
-)
-
// Kernel is a global kernel object.
//
// This contains global state, shared by multiple CPUs.
diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go
index 7206322b1..10dbd381f 100644
--- a/pkg/sentry/platform/ring0/defs_amd64.go
+++ b/pkg/sentry/platform/ring0/defs_amd64.go
@@ -20,6 +20,17 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
)
+var (
+ // UserspaceSize is the total size of userspace.
+ UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1)
+
+ // MaximumUserAddress is the largest possible user address.
+ MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1)
+
+ // KernelStartAddress is the starting kernel address.
+ KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
+)
+
// Segment indices and Selectors.
const (
// Index into GDT array.
diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go
new file mode 100644
index 000000000..dc0eeec01
--- /dev/null
+++ b/pkg/sentry/platform/ring0/defs_arm64.go
@@ -0,0 +1,136 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+var (
+ // UserspaceSize is the total size of userspace.
+ UserspaceSize = uintptr(1) << (VirtualAddressBits())
+
+ // MaximumUserAddress is the largest possible user address.
+ MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1)
+
+ // KernelStartAddress is the starting kernel address.
+ KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
+)
+
+// KernelOpts has initialization options for the kernel.
+type KernelOpts struct {
+ // PageTables are the kernel pagetables; this must be provided.
+ PageTables *pagetables.PageTables
+}
+
+// KernelArchState contains architecture-specific state.
+type KernelArchState struct {
+ KernelOpts
+}
+
+// CPUArchState contains CPU-specific arch state.
+type CPUArchState struct {
+ // stack is the stack used for interrupts on this CPU.
+ stack [512]byte
+
+ // errorCode is the error code from the last exception.
+ errorCode uintptr
+
+ // errorType indicates the type of the error code; it is always set
+ // along with the errorCode value above.
+ //
+ // It will be either 1, indicating a user error, or 0, indicating a
+ // kernel error. If ErrorCode() reports a kernel error, the fault
+ // information cannot be relied upon.
+ errorType uintptr
+
+ // faultAddr is the value of far_el1.
+ faultAddr uintptr
+
+ // ttbr0Kvm is the value of ttbr0_el1 for sentry.
+ ttbr0Kvm uintptr
+
+ // ttbr0App is the value of ttbr0_el1 for the application.
+ ttbr0App uintptr
+
+ // exception vector.
+ vecCode Vector
+
+ // application context pointer.
+ appAddr uintptr
+
+ // lazyVFP is the value of cpacr_el1.
+ lazyVFP uintptr
+}
+
+// ErrorCode returns the last error code.
+//
+// The returned boolean indicates whether the error code corresponds to the
+// last user error or not. If it does not, then fault information must be
+// ignored. This is generally the result of a kernel fault while servicing a
+// user fault.
+//
+//go:nosplit
+func (c *CPU) ErrorCode() (value uintptr, user bool) {
+ return c.errorCode, c.errorType != 0
+}
+
+// ClearErrorCode resets the error code.
+//
+//go:nosplit
+func (c *CPU) ClearErrorCode() {
+ c.errorCode = 0 // No code.
+ c.errorType = 1 // User mode.
+}
+
+//go:nosplit
+func (c *CPU) GetFaultAddr() (value uintptr) {
+ return c.faultAddr
+}
+
+//go:nosplit
+func (c *CPU) SetTtbr0Kvm(value uintptr) {
+ c.ttbr0Kvm = value
+}
+
+//go:nosplit
+func (c *CPU) SetTtbr0App(value uintptr) {
+ c.ttbr0App = value
+}
+
+//go:nosplit
+func (c *CPU) GetVector() (value Vector) {
+ return c.vecCode
+}
+
+//go:nosplit
+func (c *CPU) SetAppAddr(value uintptr) {
+ c.appAddr = value
+}
+
+// SwitchArchOpts are embedded in SwitchOpts.
+type SwitchArchOpts struct {
+ // UserASID is the application ASID to be used on switch.
+ UserASID uint16
+
+ // KernelASID is the kernel ASID to be used on return.
+ KernelASID uint16
+}
+
+func init() {
+}
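A brief sketch of how a caller is expected to honor the ErrorCode() contract documented above (the handler name here is illustrative, not part of the patch):

```go
package kvm

import "gvisor.dev/gvisor/pkg/sentry/platform/ring0"

func handleFault(c *ring0.CPU) {
	code, user := c.ErrorCode()
	if !user {
		// Kernel fault while servicing a user fault: the code is
		// unreliable and must be ignored.
		return
	}
	_ = code // dispatch on the user error code here...
	c.ClearErrorCode()
}
```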
diff --git a/pkg/sentry/platform/ring0/entry_arm64.go b/pkg/sentry/platform/ring0/entry_arm64.go
new file mode 100644
index 000000000..0dfa42c36
--- /dev/null
+++ b/pkg/sentry/platform/ring0/entry_arm64.go
@@ -0,0 +1,60 @@
+// Copyright 2019 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+// The functions below are implemented in assembly (see entry_arm64.s).
+//
+// A synchronous exception (such as a system call) can be taken in two
+// situations:
+//
+// (1) The guest kernel has executed a system call.
+// (2) The guest application has executed a system call.
+//
+// The exception level at entry determines whether the exception came from
+// kernel mode or not, and the corresponding vector stub is invoked.
+
+func El1_sync_invalid()
+func El1_irq_invalid()
+func El1_fiq_invalid()
+func El1_error_invalid()
+
+func El1_sync()
+func El1_irq()
+func El1_fiq()
+func El1_error()
+
+func El0_sync()
+func El0_irq()
+func El0_fiq()
+func El0_error()
+
+func El0_sync_invalid()
+func El0_irq_invalid()
+func El0_fiq_invalid()
+func El0_error_invalid()
+
+func Vectors()
+
+// Start is the CPU entrypoint.
+//
+// The CPU state will be set to c.Registers().
+func Start()
+func kernelExitToEl1()
+
+func kernelExitToEl0()
+
+// Shutdown stops execution.
+func Shutdown()
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
new file mode 100644
index 000000000..add2c3e08
--- /dev/null
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -0,0 +1,591 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+// NB: Offsets are programmatically generated (see BUILD).
+//
+// This file is concatenated with the definitions.
+
+// Saves a register set.
+//
+// This is a macro because it may need to be executed in contexts where a
+// not available for calls.
+//
+
+#define ERET() \
+ WORD $0xd69f03e0
+
+#define RSV_REG R18_PLATFORM
+#define RSV_REG_APP R9
+
+#define FPEN_NOTRAP 0x3
+#define FPEN_SHIFT 20
+
+#define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT)
+
+#define REGISTERS_SAVE(reg, offset) \
+ MOVD R0, offset+PTRACE_R0(reg); \
+ MOVD R1, offset+PTRACE_R1(reg); \
+ MOVD R2, offset+PTRACE_R2(reg); \
+ MOVD R3, offset+PTRACE_R3(reg); \
+ MOVD R4, offset+PTRACE_R4(reg); \
+ MOVD R5, offset+PTRACE_R5(reg); \
+ MOVD R6, offset+PTRACE_R6(reg); \
+ MOVD R7, offset+PTRACE_R7(reg); \
+ MOVD R8, offset+PTRACE_R8(reg); \
+ MOVD R10, offset+PTRACE_R10(reg); \
+ MOVD R11, offset+PTRACE_R11(reg); \
+ MOVD R12, offset+PTRACE_R12(reg); \
+ MOVD R13, offset+PTRACE_R13(reg); \
+ MOVD R14, offset+PTRACE_R14(reg); \
+ MOVD R15, offset+PTRACE_R15(reg); \
+ MOVD R16, offset+PTRACE_R16(reg); \
+ MOVD R17, offset+PTRACE_R17(reg); \
+ MOVD R19, offset+PTRACE_R19(reg); \
+ MOVD R20, offset+PTRACE_R20(reg); \
+ MOVD R21, offset+PTRACE_R21(reg); \
+ MOVD R22, offset+PTRACE_R22(reg); \
+ MOVD R23, offset+PTRACE_R23(reg); \
+ MOVD R24, offset+PTRACE_R24(reg); \
+ MOVD R25, offset+PTRACE_R25(reg); \
+ MOVD R26, offset+PTRACE_R26(reg); \
+ MOVD R27, offset+PTRACE_R27(reg); \
+ MOVD g, offset+PTRACE_R28(reg); \
+ MOVD R29, offset+PTRACE_R29(reg); \
+ MOVD R30, offset+PTRACE_R30(reg);
+
+#define REGISTERS_LOAD(reg, offset) \
+ MOVD offset+PTRACE_R0(reg), R0; \
+ MOVD offset+PTRACE_R1(reg), R1; \
+ MOVD offset+PTRACE_R2(reg), R2; \
+ MOVD offset+PTRACE_R3(reg), R3; \
+ MOVD offset+PTRACE_R4(reg), R4; \
+ MOVD offset+PTRACE_R5(reg), R5; \
+ MOVD offset+PTRACE_R6(reg), R6; \
+ MOVD offset+PTRACE_R7(reg), R7; \
+ MOVD offset+PTRACE_R8(reg), R8; \
+ MOVD offset+PTRACE_R10(reg), R10; \
+ MOVD offset+PTRACE_R11(reg), R11; \
+ MOVD offset+PTRACE_R12(reg), R12; \
+ MOVD offset+PTRACE_R13(reg), R13; \
+ MOVD offset+PTRACE_R14(reg), R14; \
+ MOVD offset+PTRACE_R15(reg), R15; \
+ MOVD offset+PTRACE_R16(reg), R16; \
+ MOVD offset+PTRACE_R17(reg), R17; \
+ MOVD offset+PTRACE_R19(reg), R19; \
+ MOVD offset+PTRACE_R20(reg), R20; \
+ MOVD offset+PTRACE_R21(reg), R21; \
+ MOVD offset+PTRACE_R22(reg), R22; \
+ MOVD offset+PTRACE_R23(reg), R23; \
+ MOVD offset+PTRACE_R24(reg), R24; \
+ MOVD offset+PTRACE_R25(reg), R25; \
+ MOVD offset+PTRACE_R26(reg), R26; \
+ MOVD offset+PTRACE_R27(reg), R27; \
+ MOVD offset+PTRACE_R28(reg), g; \
+ MOVD offset+PTRACE_R29(reg), R29; \
+ MOVD offset+PTRACE_R30(reg), R30;
+
+// nop31Instructions emits 31 NOPs, padding each 4-byte branch below to the
+// 128-byte (32-instruction) size of a vector table entry.
+#define nop31Instructions() \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f;
+
+#define ESR_ELx_EC_UNKNOWN (0x00)
+#define ESR_ELx_EC_WFx (0x01)
+/* Unallocated EC: 0x02 */
+#define ESR_ELx_EC_CP15_32 (0x03)
+#define ESR_ELx_EC_CP15_64 (0x04)
+#define ESR_ELx_EC_CP14_MR (0x05)
+#define ESR_ELx_EC_CP14_LS (0x06)
+#define ESR_ELx_EC_FP_ASIMD (0x07)
+#define ESR_ELx_EC_CP10_ID (0x08) /* EL2 only */
+#define ESR_ELx_EC_PAC (0x09) /* EL2 and above */
+/* Unallocated EC: 0x0A - 0x0B */
+#define ESR_ELx_EC_CP14_64 (0x0C)
+/* Unallocated EC: 0x0d */
+#define ESR_ELx_EC_ILL (0x0E)
+/* Unallocated EC: 0x0F - 0x10 */
+#define ESR_ELx_EC_SVC32 (0x11)
+#define ESR_ELx_EC_HVC32 (0x12) /* EL2 only */
+#define ESR_ELx_EC_SMC32 (0x13) /* EL2 and above */
+/* Unallocated EC: 0x14 */
+#define ESR_ELx_EC_SVC64 (0x15)
+#define ESR_ELx_EC_HVC64 (0x16) /* EL2 and above */
+#define ESR_ELx_EC_SMC64 (0x17) /* EL2 and above */
+#define ESR_ELx_EC_SYS64 (0x18)
+#define ESR_ELx_EC_SVE (0x19)
+/* Unallocated EC: 0x1A - 0x1E */
+#define ESR_ELx_EC_IMP_DEF (0x1f) /* EL3 only */
+#define ESR_ELx_EC_IABT_LOW (0x20)
+#define ESR_ELx_EC_IABT_CUR (0x21)
+#define ESR_ELx_EC_PC_ALIGN (0x22)
+/* Unallocated EC: 0x23 */
+#define ESR_ELx_EC_DABT_LOW (0x24)
+#define ESR_ELx_EC_DABT_CUR (0x25)
+#define ESR_ELx_EC_SP_ALIGN (0x26)
+/* Unallocated EC: 0x27 */
+#define ESR_ELx_EC_FP_EXC32 (0x28)
+/* Unallocated EC: 0x29 - 0x2B */
+#define ESR_ELx_EC_FP_EXC64 (0x2C)
+/* Unallocated EC: 0x2D - 0x2E */
+#define ESR_ELx_EC_SERROR (0x2F)
+#define ESR_ELx_EC_BREAKPT_LOW (0x30)
+#define ESR_ELx_EC_BREAKPT_CUR (0x31)
+#define ESR_ELx_EC_SOFTSTP_LOW (0x32)
+#define ESR_ELx_EC_SOFTSTP_CUR (0x33)
+#define ESR_ELx_EC_WATCHPT_LOW (0x34)
+#define ESR_ELx_EC_WATCHPT_CUR (0x35)
+/* Unallocated EC: 0x36 - 0x37 */
+#define ESR_ELx_EC_BKPT32 (0x38)
+/* Unallocated EC: 0x39 */
+#define ESR_ELx_EC_VECTOR32 (0x3A) /* EL2 only */
+/* Unallocated EC: 0x3B */
+#define ESR_ELx_EC_BRK64 (0x3C)
+/* Unallocated EC: 0x3D - 0x3F */
+#define ESR_ELx_EC_MAX (0x3F)
+
+#define ESR_ELx_EC_SHIFT (26)
+#define ESR_ELx_EC_MASK (UL(0x3F) << ESR_ELx_EC_SHIFT)
+#define ESR_ELx_EC(esr) (((esr) & ESR_ELx_EC_MASK) >> ESR_ELx_EC_SHIFT)
+
+#define ESR_ELx_IL_SHIFT (25)
+#define ESR_ELx_IL (UL(1) << ESR_ELx_IL_SHIFT)
+#define ESR_ELx_ISS_MASK (ESR_ELx_IL - 1)
+
+/* ISS field definitions shared by different classes */
+#define ESR_ELx_WNR_SHIFT (6)
+#define ESR_ELx_WNR (UL(1) << ESR_ELx_WNR_SHIFT)
+
+/* Asynchronous Error Type */
+#define ESR_ELx_IDS_SHIFT (24)
+#define ESR_ELx_IDS (UL(1) << ESR_ELx_IDS_SHIFT)
+#define ESR_ELx_AET_SHIFT (10)
+#define ESR_ELx_AET (UL(0x7) << ESR_ELx_AET_SHIFT)
+
+#define ESR_ELx_AET_UC (UL(0) << ESR_ELx_AET_SHIFT)
+#define ESR_ELx_AET_UEU (UL(1) << ESR_ELx_AET_SHIFT)
+#define ESR_ELx_AET_UEO (UL(2) << ESR_ELx_AET_SHIFT)
+#define ESR_ELx_AET_UER (UL(3) << ESR_ELx_AET_SHIFT)
+#define ESR_ELx_AET_CE (UL(6) << ESR_ELx_AET_SHIFT)
+
+/* Shared ISS field definitions for Data/Instruction aborts */
+#define ESR_ELx_SET_SHIFT (11)
+#define ESR_ELx_SET_MASK (UL(3) << ESR_ELx_SET_SHIFT)
+#define ESR_ELx_FnV_SHIFT (10)
+#define ESR_ELx_FnV (UL(1) << ESR_ELx_FnV_SHIFT)
+#define ESR_ELx_EA_SHIFT (9)
+#define ESR_ELx_EA (UL(1) << ESR_ELx_EA_SHIFT)
+#define ESR_ELx_S1PTW_SHIFT (7)
+#define ESR_ELx_S1PTW (UL(1) << ESR_ELx_S1PTW_SHIFT)
+
+/* Shared ISS fault status code(IFSC/DFSC) for Data/Instruction aborts */
+#define ESR_ELx_FSC (0x3F)
+#define ESR_ELx_FSC_TYPE (0x3C)
+#define ESR_ELx_FSC_EXTABT (0x10)
+#define ESR_ELx_FSC_SERROR (0x11)
+#define ESR_ELx_FSC_ACCESS (0x08)
+#define ESR_ELx_FSC_FAULT (0x04)
+#define ESR_ELx_FSC_PERM (0x0C)
+
+/* ISS field definitions for Data Aborts */
+#define ESR_ELx_ISV_SHIFT (24)
+#define ESR_ELx_ISV (UL(1) << ESR_ELx_ISV_SHIFT)
+#define ESR_ELx_SAS_SHIFT (22)
+#define ESR_ELx_SAS (UL(3) << ESR_ELx_SAS_SHIFT)
+#define ESR_ELx_SSE_SHIFT (21)
+#define ESR_ELx_SSE (UL(1) << ESR_ELx_SSE_SHIFT)
+#define ESR_ELx_SRT_SHIFT (16)
+#define ESR_ELx_SRT_MASK (UL(0x1F) << ESR_ELx_SRT_SHIFT)
+#define ESR_ELx_SF_SHIFT (15)
+#define ESR_ELx_SF (UL(1) << ESR_ELx_SF_SHIFT)
+#define ESR_ELx_AR_SHIFT (14)
+#define ESR_ELx_AR (UL(1) << ESR_ELx_AR_SHIFT)
+#define ESR_ELx_CM_SHIFT (8)
+#define ESR_ELx_CM (UL(1) << ESR_ELx_CM_SHIFT)
+
+/* ISS field definitions for exceptions taken in to Hyp */
+#define ESR_ELx_CV (UL(1) << 24)
+#define ESR_ELx_COND_SHIFT (20)
+#define ESR_ELx_COND_MASK (UL(0xF) << ESR_ELx_COND_SHIFT)
+#define ESR_ELx_WFx_ISS_TI (UL(1) << 0)
+#define ESR_ELx_WFx_ISS_WFI (UL(0) << 0)
+#define ESR_ELx_WFx_ISS_WFE (UL(1) << 0)
+#define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1)
+
+#define LOAD_KERNEL_ADDRESS(from, to) \
+ MOVD from, to; \
+ ORR $0xffff000000000000, to, to;
+
+// LOAD_KERNEL_STACK loads the kernel temporary stack.
+#define LOAD_KERNEL_STACK(from) \
+ LOAD_KERNEL_ADDRESS(CPU_SELF(from), RSV_REG); \
+ MOVD $CPU_STACK_TOP(RSV_REG), RSV_REG; \
+ MOVD RSV_REG, RSP; \
+ ISB $15; \
+ DSB $15;
+
+#define SWITCH_TO_APP_PAGETABLE(from) \
+ MOVD CPU_TTBR0_APP(from), RSV_REG; \
+ WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
+ ISB $15; \
+ DSB $15;
+
+#define SWITCH_TO_KVM_PAGETABLE(from) \
+ MOVD CPU_TTBR0_KVM(from), RSV_REG; \
+ WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
+ ISB $15; \
+ DSB $15;
+
+// Note: MSR DAIFClr clears the IRQ mask bit (enabling IRQs); DAIFSet sets it.
+#define IRQ_ENABLE \
+ MSR $2, DAIFClr;
+
+#define IRQ_DISABLE \
+ MSR $2, DAIFSet;
+
+#define VFP_ENABLE \
+ MOVD $FPEN_ENABLE, R0; \
+ WORD $0xd5181040; \ //MSR R0, CPACR_EL1
+ ISB $15;
+
+#define VFP_DISABLE \
+ MOVD $0x0, R0; \
+ WORD $0xd5181040; \ //MSR R0, CPACR_EL1
+ ISB $15;
+
+#define KERNEL_ENTRY_FROM_EL0 \
+ SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack.
+ STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \
+ WORD $0xd538d092; \ // MRS TPIDR_EL1, R18; step2, switch to the kernel page table.
+ SWITCH_TO_KVM_PAGETABLE(RSV_REG); \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
+ MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step3, load app context pointer.
+ REGISTERS_SAVE(RSV_REG_APP, 0); \ // step4, save app context.
+ MOVD RSV_REG_APP, R20; \
+ LDP 16*0(RSP), (RSV_REG, RSV_REG_APP); \
+ ADD $16, RSP, RSP; \
+ MOVD RSV_REG, PTRACE_R18(R20); \
+ MOVD RSV_REG_APP, PTRACE_R9(R20); \
+ MOVD R20, RSV_REG_APP; \
+ WORD $0xd5384003; \ // MRS SPSR_EL1, R3
+ MOVD R3, PTRACE_PSTATE(RSV_REG_APP); \
+ MRS ELR_EL1, R3; \
+ MOVD R3, PTRACE_PC(RSV_REG_APP); \
+ WORD $0xd5384103; \ // MRS SP_EL0, R3
+ MOVD R3, PTRACE_SP(RSV_REG_APP);
+
+#define KERNEL_ENTRY_FROM_EL1 \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
+ REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // save sentry context
+ MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG); \
+ WORD $0xd5384004; \ // MRS SPSR_EL1, R4
+ MOVD R4, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG); \
+ MRS ELR_EL1, R4; \
+ MOVD R4, CPU_REGISTERS+PTRACE_PC(RSV_REG); \
+ MOVD RSP, R4; \
+ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG);
+
+TEXT ·Halt(SB),NOSPLIT,$0
+ // clear bluepill.
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ CMP RSV_REG, R9
+ BNE mmio_exit
+ MOVD $0, CPU_REGISTERS+PTRACE_R9(RSV_REG)
+mmio_exit:
+ // Disable fpsimd.
+ WORD $0xd5381041 // MRS CPACR_EL1, R1
+ MOVD R1, CPU_LAZY_VFP(RSV_REG)
+ VFP_DISABLE
+
+ // MMIO_EXIT.
+ MOVD $0, R9
+ MOVD R0, 0xffff000000001000(R9)
+ B ·kernelExitToEl1(SB)
+
+TEXT ·Shutdown(SB),NOSPLIT,$0
+ // PSCI SYSTEM_RESET call (function ID 0x84000009).
+ MOVD $0x84000009, R0
+ HVC $0
+
+// See kernel.go.
+TEXT ·Current(SB),NOSPLIT,$0-8
+ MOVD CPU_SELF(RSV_REG), R8
+ MOVD R8, ret+0(FP)
+ RET
+
+#define STACK_FRAME_SIZE 16
+
+TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
+ ERET()
+
+TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
+ ERET()
+
+TEXT ·Start(SB),NOSPLIT,$0
+ IRQ_DISABLE
+ MOVD R8, RSV_REG
+ ORR $0xffff000000000000, RSV_REG, RSV_REG
+ WORD $0xd518d092 //MSR R18, TPIDR_EL1
+
+ B ·kernelExitToEl1(SB)
+
+TEXT ·El1_sync_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El1_irq_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El1_fiq_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El1_error_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El1_sync(SB),NOSPLIT,$0
+ KERNEL_ENTRY_FROM_EL1
+ WORD $0xd5385219 // MRS ESR_EL1, R25
+ LSR $ESR_ELx_EC_SHIFT, R25, R24
+ CMP $ESR_ELx_EC_DABT_CUR, R24
+ BEQ el1_da
+ CMP $ESR_ELx_EC_IABT_CUR, R24
+ BEQ el1_ia
+ CMP $ESR_ELx_EC_SYS64, R24
+ BEQ el1_undef
+ CMP $ESR_ELx_EC_SP_ALIGN, R24
+ BEQ el1_sp_pc
+ CMP $ESR_ELx_EC_PC_ALIGN, R24
+ BEQ el1_sp_pc
+ CMP $ESR_ELx_EC_UNKNOWN, R24
+ BEQ el1_undef
+ CMP $ESR_ELx_EC_SVC64, R24
+ BEQ el1_svc
+ CMP $ESR_ELx_EC_BREAKPT_CUR, R24
+ BGE el1_dbg
+ CMP $ESR_ELx_EC_FP_ASIMD, R24
+ BEQ el1_fpsimd_acc
+ B el1_invalid
+
+el1_da:
+ B ·Halt(SB)
+
+el1_ia:
+ B ·Halt(SB)
+
+el1_sp_pc:
+ B ·Shutdown(SB)
+
+el1_undef:
+ B ·Shutdown(SB)
+
+el1_svc:
+ B ·Halt(SB)
+
+el1_dbg:
+ B ·Shutdown(SB)
+
+el1_fpsimd_acc:
+ VFP_ENABLE
+ B ·kernelExitToEl1(SB) // Resume.
+
+el1_invalid:
+ B ·Shutdown(SB)
+
+TEXT ·El1_irq(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El1_fiq(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El1_error(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_sync(SB),NOSPLIT,$0
+ KERNEL_ENTRY_FROM_EL0
+ WORD $0xd5385219 // MRS ESR_EL1, R25
+ LSR $ESR_ELx_EC_SHIFT, R25, R24
+ CMP $ESR_ELx_EC_SVC64, R24
+ BEQ el0_svc
+ CMP $ESR_ELx_EC_DABT_LOW, R24
+ BEQ el0_da
+ CMP $ESR_ELx_EC_IABT_LOW, R24
+ BEQ el0_ia
+ CMP $ESR_ELx_EC_FP_ASIMD, R24
+ BEQ el0_fpsimd_acc
+ CMP $ESR_ELx_EC_SVE, R24
+ BEQ el0_sve_acc
+ CMP $ESR_ELx_EC_FP_EXC64, R24
+ BEQ el0_fpsimd_exc
+ CMP $ESR_ELx_EC_SP_ALIGN, R24
+ BEQ el0_sp_pc
+ CMP $ESR_ELx_EC_PC_ALIGN, R24
+ BEQ el0_sp_pc
+ CMP $ESR_ELx_EC_UNKNOWN, R24
+ BEQ el0_undef
+ CMP $ESR_ELx_EC_BREAKPT_LOW, R24
+ BGE el0_dbg
+ B el0_invalid
+
+el0_svc:
+ B ·Halt(SB)
+
+el0_da:
+ B ·Halt(SB)
+
+el0_ia:
+ B ·Shutdown(SB)
+
+el0_fpsimd_acc:
+ B ·Shutdown(SB)
+
+el0_sve_acc:
+ B ·Shutdown(SB)
+
+el0_fpsimd_exc:
+ B ·Shutdown(SB)
+
+el0_sp_pc:
+ B ·Shutdown(SB)
+
+el0_undef:
+ B ·Shutdown(SB)
+
+el0_dbg:
+ B ·Shutdown(SB)
+
+el0_invalid:
+ B ·Shutdown(SB)
+
+TEXT ·El0_irq(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_fiq(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_error(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_sync_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_irq_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_fiq_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_error_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·Vectors(SB),NOSPLIT,$0
+ B ·El1_sync_invalid(SB)
+ nop31Instructions()
+ B ·El1_irq_invalid(SB)
+ nop31Instructions()
+ B ·El1_fiq_invalid(SB)
+ nop31Instructions()
+ B ·El1_error_invalid(SB)
+ nop31Instructions()
+
+ B ·El1_sync(SB)
+ nop31Instructions()
+ B ·El1_irq(SB)
+ nop31Instructions()
+ B ·El1_fiq(SB)
+ nop31Instructions()
+ B ·El1_error(SB)
+ nop31Instructions()
+
+ B ·El0_sync(SB)
+ nop31Instructions()
+ B ·El0_irq(SB)
+ nop31Instructions()
+ B ·El0_fiq(SB)
+ nop31Instructions()
+ B ·El0_error(SB)
+ nop31Instructions()
+
+ B ·El0_sync_invalid(SB)
+ nop31Instructions()
+ B ·El0_irq_invalid(SB)
+ nop31Instructions()
+ B ·El0_fiq_invalid(SB)
+ nop31Instructions()
+ B ·El0_error_invalid(SB)
+ nop31Instructions()
+
+ WORD $0xd503201f //nop
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
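The padding discipline in ·Vectors above can be sanity-checked with a little arithmetic: each AArch64 vector table entry is 0x80 bytes, so one 4-byte branch plus nop31Instructions fills an entry exactly, and sixteen entries give the 0x800-byte table that updateVectorTable copies.

```go
package main

import "fmt"

func main() {
	const entryBytes = 0x80 // size of one vector table entry
	const instrBytes = 4    // fixed AArch64 instruction width
	fmt.Println("instructions per entry:", entryBytes/instrBytes) // 32 = 1 branch + 31 NOPs
	fmt.Printf("table size: %#x bytes\n", 16*entryBytes)          // 0x800, the copy-loop bound
}
```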
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
index 780bf9a66..42076fb04 100644
--- a/pkg/sentry/platform/ring0/gen_offsets/BUILD
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -4,16 +4,24 @@ load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
go_template_instance(
- name = "defs_impl",
- out = "defs_impl.go",
+ name = "defs_impl_arm64",
+ out = "defs_impl_arm64.go",
package = "main",
- template = "//pkg/sentry/platform/ring0:defs",
+ template = "//pkg/sentry/platform/ring0:defs_arm64",
+)
+
+go_template_instance(
+ name = "defs_impl_amd64",
+ out = "defs_impl_amd64.go",
+ package = "main",
+ template = "//pkg/sentry/platform/ring0:defs_amd64",
)
go_binary(
name = "gen_offsets",
srcs = [
- "defs_impl.go",
+ "defs_impl_amd64.go",
+ "defs_impl_arm64.go",
"main.go",
],
visibility = ["//pkg/sentry/platform/ring0:__pkg__"],
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
new file mode 100644
index 000000000..ed82a131e
--- /dev/null
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -0,0 +1,58 @@
+// Copyright 2019 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+// init initializes architecture-specific state.
+func (k *Kernel) init(opts KernelOpts) {
+ // Save the root page tables.
+ k.PageTables = opts.PageTables
+}
+
+// init initializes architecture-specific state.
+func (c *CPU) init() {
+ // Set the kernel stack pointer (virtual address).
+ c.registers.Sp = uint64(c.StackTop())
+}
+
+// StackTop returns the kernel's stack address.
+//
+//go:nosplit
+func (c *CPU) StackTop() uint64 {
+ return uint64(kernelAddr(&c.stack[0])) + uint64(len(c.stack))
+}
+
+// IsCanonical indicates whether addr is canonical per the arm64 spec.
+//
+//go:nosplit
+func IsCanonical(addr uint64) bool {
+ return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000
+}
+
+//go:nosplit
+func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
+ // Sanitize registers.
+ regs := switchOpts.Registers
+ regs.Pstate &= ^uint64(UserFlagsClear)
+ regs.Pstate |= UserFlagsSet
+
+ // Perform the switch.
+ kernelExitToEl0()
+ vector = c.vecCode
+ return
+}
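The IsCanonical check above encodes the arm64 split address space: user addresses (TTBR0) have the top 16 bits clear, kernel addresses (TTBR1) have them set. A few worked values:

```go
package main

import "fmt"

func isCanonical(addr uint64) bool {
	return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000
}

func main() {
	fmt.Println(isCanonical(0x0000ffffffffffff)) // true: highest user address
	fmt.Println(isCanonical(0x0001000000000000)) // false: non-canonical hole
	fmt.Println(isCanonical(0xffff000000000001)) // true: kernel range
}
```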
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
new file mode 100644
index 000000000..8bcfe1032
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -0,0 +1,39 @@
+// Copyright 2019 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+// CPACREL1 returns the value of the CPACR_EL1 register.
+func CPACREL1() (value uintptr)
+
+// FPCR returns the value of the FPCR register.
+func FPCR() (value uintptr)
+
+// SetFPCR writes the FPCR value.
+func SetFPCR(value uintptr)
+
+// FPSR returns the value of the FPSR register.
+func FPSR() (value uintptr)
+
+// SetFPSR writes the FPSR value.
+func SetFPSR(value uintptr)
+
+// SaveVRegs saves V0-V31 registers.
+// V0-V31: 32 128-bit registers for floating point and simd.
+func SaveVRegs(*byte)
+
+// LoadVRegs loads V0-V31 registers.
+func LoadVRegs(*byte)
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
new file mode 100644
index 000000000..1c9171004
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -0,0 +1,118 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+TEXT ·CPACREL1(SB),NOSPLIT,$0-8
+ WORD $0xd5381041 // MRS CPACR_EL1, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·FPCR(SB),NOSPLIT,$0-8
+ WORD $0xd53b4401 // MRS FPCR, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·FPSR(SB),NOSPLIT,$0-8
+ WORD $0xd53b4421 // MRS FPSR, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·SetFPCR(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+ WORD $0xd51b4401 // MSR R1, FPCR
+ RET
+
+TEXT ·SetFPSR(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+ WORD $0xd51b4421 // MSR R1, FPSR
+ RET
+
+TEXT ·SaveVRegs(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R0
+
+ // Skip aarch64_ctx, fpsr, fpcr.
+ FMOVD F0, 16*1(R0)
+ FMOVD F1, 16*2(R0)
+ FMOVD F2, 16*3(R0)
+ FMOVD F3, 16*4(R0)
+ FMOVD F4, 16*5(R0)
+ FMOVD F5, 16*6(R0)
+ FMOVD F6, 16*7(R0)
+ FMOVD F7, 16*8(R0)
+ FMOVD F8, 16*9(R0)
+ FMOVD F9, 16*10(R0)
+ FMOVD F10, 16*11(R0)
+ FMOVD F11, 16*12(R0)
+ FMOVD F12, 16*13(R0)
+ FMOVD F13, 16*14(R0)
+ FMOVD F14, 16*15(R0)
+ FMOVD F15, 16*16(R0)
+ FMOVD F16, 16*17(R0)
+ FMOVD F17, 16*18(R0)
+ FMOVD F18, 16*19(R0)
+ FMOVD F19, 16*20(R0)
+ FMOVD F20, 16*21(R0)
+ FMOVD F21, 16*22(R0)
+ FMOVD F22, 16*23(R0)
+ FMOVD F23, 16*24(R0)
+ FMOVD F24, 16*25(R0)
+ FMOVD F25, 16*26(R0)
+ FMOVD F26, 16*27(R0)
+ FMOVD F27, 16*28(R0)
+ FMOVD F28, 16*29(R0)
+ FMOVD F29, 16*30(R0)
+ FMOVD F30, 16*31(R0)
+ FMOVD F31, 16*32(R0)
+ ISB $15
+
+ RET
+
+TEXT ·LoadVRegs(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R0
+
+ // Skip aarch64_ctx, fpsr, fpcr.
+ FMOVD 16*1(R0), F0
+ FMOVD 16*2(R0), F1
+ FMOVD 16*3(R0), F2
+ FMOVD 16*4(R0), F3
+ FMOVD 16*5(R0), F4
+ FMOVD 16*6(R0), F5
+ FMOVD 16*7(R0), F6
+ FMOVD 16*8(R0), F7
+ FMOVD 16*9(R0), F8
+ FMOVD 16*10(R0), F9
+ FMOVD 16*11(R0), F10
+ FMOVD 16*12(R0), F11
+ FMOVD 16*13(R0), F12
+ FMOVD 16*14(R0), F13
+ FMOVD 16*15(R0), F14
+ FMOVD 16*16(R0), F15
+ FMOVD 16*17(R0), F16
+ FMOVD 16*18(R0), F17
+ FMOVD 16*19(R0), F18
+ FMOVD 16*20(R0), F19
+ FMOVD 16*21(R0), F20
+ FMOVD 16*22(R0), F21
+ FMOVD 16*23(R0), F22
+ FMOVD 16*24(R0), F23
+ FMOVD 16*25(R0), F24
+ FMOVD 16*26(R0), F25
+ FMOVD 16*27(R0), F26
+ FMOVD 16*28(R0), F27
+ FMOVD 16*29(R0), F28
+ FMOVD 16*30(R0), F29
+ FMOVD 16*31(R0), F30
+ FMOVD 16*32(R0), F31
+ ISB $15
+
+ RET
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
new file mode 100644
index 000000000..cd2a65f97
--- /dev/null
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -0,0 +1,125 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+ "syscall"
+)
+
+// Emit prints architecture-specific offsets.
+func Emit(w io.Writer) {
+ fmt.Fprintf(w, "// Automatically generated, do not edit.\n")
+
+ c := &CPU{}
+ fmt.Fprintf(w, "\n// CPU offsets.\n")
+ fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack)))
+ fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_FAULT_ADDR 0x%02x\n", reflect.ValueOf(&c.faultAddr).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_TTBR0_KVM 0x%02x\n", reflect.ValueOf(&c.ttbr0Kvm).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_TTBR0_APP 0x%02x\n", reflect.ValueOf(&c.ttbr0App).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer())
+
+ fmt.Fprintf(w, "\n// Bits.\n")
+ fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
+
+ fmt.Fprintf(w, "\n// Vectors.\n")
+ fmt.Fprintf(w, "#define El1SyncInvalid 0x%02x\n", El1SyncInvalid)
+ fmt.Fprintf(w, "#define El1IrqInvalid 0x%02x\n", El1IrqInvalid)
+ fmt.Fprintf(w, "#define El1FiqInvalid 0x%02x\n", El1FiqInvalid)
+ fmt.Fprintf(w, "#define El1ErrorInvalid 0x%02x\n", El1ErrorInvalid)
+
+ fmt.Fprintf(w, "#define El1Sync 0x%02x\n", El1Sync)
+ fmt.Fprintf(w, "#define El1Irq 0x%02x\n", El1Irq)
+ fmt.Fprintf(w, "#define El1Fiq 0x%02x\n", El1Fiq)
+ fmt.Fprintf(w, "#define El1Error 0x%02x\n", El1Error)
+
+ fmt.Fprintf(w, "#define El0Sync 0x%02x\n", El0Sync)
+ fmt.Fprintf(w, "#define El0Irq 0x%02x\n", El0Irq)
+ fmt.Fprintf(w, "#define El0Fiq 0x%02x\n", El0Fiq)
+ fmt.Fprintf(w, "#define El0Error 0x%02x\n", El0Error)
+
+ fmt.Fprintf(w, "#define El0Sync_invalid 0x%02x\n", El0Sync_invalid)
+ fmt.Fprintf(w, "#define El0Irq_invalid 0x%02x\n", El0Irq_invalid)
+ fmt.Fprintf(w, "#define El0Fiq_invalid 0x%02x\n", El0Fiq_invalid)
+ fmt.Fprintf(w, "#define El0Error_invalid 0x%02x\n", El0Error_invalid)
+
+ fmt.Fprintf(w, "#define El1Sync_da 0x%02x\n", El1Sync_da)
+ fmt.Fprintf(w, "#define El1Sync_ia 0x%02x\n", El1Sync_ia)
+ fmt.Fprintf(w, "#define El1Sync_sp_pc 0x%02x\n", El1Sync_sp_pc)
+ fmt.Fprintf(w, "#define El1Sync_undef 0x%02x\n", El1Sync_undef)
+ fmt.Fprintf(w, "#define El1Sync_dbg 0x%02x\n", El1Sync_dbg)
+ fmt.Fprintf(w, "#define El1Sync_inv 0x%02x\n", El1Sync_inv)
+
+ fmt.Fprintf(w, "#define El0Sync_svc 0x%02x\n", El0Sync_svc)
+ fmt.Fprintf(w, "#define El0Sync_da 0x%02x\n", El0Sync_da)
+ fmt.Fprintf(w, "#define El0Sync_ia 0x%02x\n", El0Sync_ia)
+ fmt.Fprintf(w, "#define El0Sync_fpsimd_acc 0x%02x\n", El0Sync_fpsimd_acc)
+ fmt.Fprintf(w, "#define El0Sync_sve_acc 0x%02x\n", El0Sync_sve_acc)
+ fmt.Fprintf(w, "#define El0Sync_sys 0x%02x\n", El0Sync_sys)
+ fmt.Fprintf(w, "#define El0Sync_sp_pc 0x%02x\n", El0Sync_sp_pc)
+ fmt.Fprintf(w, "#define El0Sync_undef 0x%02x\n", El0Sync_undef)
+ fmt.Fprintf(w, "#define El0Sync_dbg 0x%02x\n", El0Sync_dbg)
+ fmt.Fprintf(w, "#define El0Sync_inv 0x%02x\n", El0Sync_inv)
+
+ fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault)
+ fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall)
+
+ p := &syscall.PtraceRegs{}
+ fmt.Fprintf(w, "\n// Ptrace registers.\n")
+ fmt.Fprintf(w, "#define PTRACE_R0 0x%02x\n", reflect.ValueOf(&p.Regs[0]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R1 0x%02x\n", reflect.ValueOf(&p.Regs[1]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R2 0x%02x\n", reflect.ValueOf(&p.Regs[2]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R3 0x%02x\n", reflect.ValueOf(&p.Regs[3]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R4 0x%02x\n", reflect.ValueOf(&p.Regs[4]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R5 0x%02x\n", reflect.ValueOf(&p.Regs[5]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R6 0x%02x\n", reflect.ValueOf(&p.Regs[6]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R7 0x%02x\n", reflect.ValueOf(&p.Regs[7]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.Regs[8]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.Regs[9]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.Regs[10]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.Regs[11]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.Regs[12]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.Regs[13]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.Regs[14]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.Regs[15]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R16 0x%02x\n", reflect.ValueOf(&p.Regs[16]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R17 0x%02x\n", reflect.ValueOf(&p.Regs[17]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R18 0x%02x\n", reflect.ValueOf(&p.Regs[18]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R19 0x%02x\n", reflect.ValueOf(&p.Regs[19]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R20 0x%02x\n", reflect.ValueOf(&p.Regs[20]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R21 0x%02x\n", reflect.ValueOf(&p.Regs[21]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R22 0x%02x\n", reflect.ValueOf(&p.Regs[22]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R23 0x%02x\n", reflect.ValueOf(&p.Regs[23]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R24 0x%02x\n", reflect.ValueOf(&p.Regs[24]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R25 0x%02x\n", reflect.ValueOf(&p.Regs[25]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R26 0x%02x\n", reflect.ValueOf(&p.Regs[26]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R27 0x%02x\n", reflect.ValueOf(&p.Regs[27]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R28 0x%02x\n", reflect.ValueOf(&p.Regs[28]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R29 0x%02x\n", reflect.ValueOf(&p.Regs[29]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R30 0x%02x\n", reflect.ValueOf(&p.Regs[30]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_SP 0x%02x\n", reflect.ValueOf(&p.Sp).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_PC 0x%02x\n", reflect.ValueOf(&p.Pc).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_PSTATE 0x%02x\n", reflect.ValueOf(&p.Pstate).Pointer()-reflect.ValueOf(p).Pointer())
+}
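Emit derives every offset at runtime by subtracting the struct's base address from each field's address via reflect, then prints it as a C-style #define for the assembly to include. A self-contained sketch of the same trick, with a hypothetical cpu struct standing in for ring0.CPU:

package main

import (
	"fmt"
	"os"
	"reflect"
)

// cpu is a stand-in for the real CPU type; only the layout matters.
type cpu struct {
	self      *cpu
	registers [31]uint64
	errorCode uintptr
}

func main() {
	c := &cpu{}
	base := reflect.ValueOf(c).Pointer()
	// Field address minus base address yields the field offset.
	fmt.Fprintf(os.Stdout, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-base)
	fmt.Fprintf(os.Stdout, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-base)
	fmt.Fprintf(os.Stdout, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-base)
}

The same constants could come from unsafe.Offsetof; the reflect form keeps the generator free of unsafe and also works for computed addresses such as the end of c.stack.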
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index 934a90378..e2e15ba5c 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -1,14 +1,17 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
+config_setting(
+ name = "aarch64",
+ constraint_values = ["@bazel_tools//platforms:aarch64"],
+)
+
go_template(
name = "generic_walker",
- srcs = [
- "walker_amd64.go",
- ],
+ srcs = ["walker_amd64.go"],
opt_types = [
"Visitor",
],
@@ -76,9 +79,13 @@ go_library(
"allocator.go",
"allocator_unsafe.go",
"pagetables.go",
+ "pagetables_aarch64.go",
"pagetables_amd64.go",
+ "pagetables_arm64.go",
"pagetables_x86.go",
"pcids_x86.go",
+ "walker_amd64.go",
+ "walker_arm64.go",
"walker_empty.go",
"walker_lookup.go",
"walker_map.go",
@@ -97,6 +104,7 @@ go_test(
size = "small",
srcs = [
"pagetables_amd64_test.go",
+ "pagetables_arm64_test.go",
"pagetables_test.go",
"walker_check.go",
],
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go
index 904f1a6de..30c64a372 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go
@@ -48,15 +48,6 @@ func New(a Allocator) *PageTables {
return p
}
-// Init initializes a set of PageTables.
-//
-//go:nosplit
-func (p *PageTables) Init(allocator Allocator) {
- p.Allocator = allocator
- p.root = p.Allocator.NewPTEs()
- p.rootPhysical = p.Allocator.PhysicalFor(p.root)
-}
-
// mapVisitor is used for map.
type mapVisitor struct {
target uintptr // Input.
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
new file mode 100644
index 000000000..e78424766
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
@@ -0,0 +1,212 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package pagetables
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+// archPageTables is architecture-specific data.
+type archPageTables struct {
+ // root is the pagetable root for kernel space.
+ root *PTEs
+
+ // rootPhysical is the cached physical address of the root.
+ //
+ // This is saved only to prevent constant translation.
+ rootPhysical uintptr
+
+ asid uint16
+}
+
+// TTBR0_EL1 returns the translation table base register 0.
+//
+//go:nosplit
+func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 {
+ return uint64(p.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
+}
+
+// TTBR1_EL1 returns the translation table base register 1.
+//
+//go:nosplit
+func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 {
+ return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
+}
+
+// Bits in page table entries.
+const (
+ typeTable = 0x3 << 0
+ typeSect = 0x1 << 0
+ typePage = 0x3 << 0
+ pteValid = 0x1 << 0
+ pteTableBit = 0x1 << 1
+ pteTypeMask = 0x3 << 0
+ present = pteValid | pteTableBit
+ user = 0x1 << 6 /* AP[1] */
+ readOnly = 0x1 << 7 /* AP[2] */
+ accessed = 0x1 << 10
+ dbm = 0x1 << 51
+ writable = dbm
+ cont = 0x1 << 52
+ pxn = 0x1 << 53
+ xn = 0x1 << 54
+ dirty = 0x1 << 55
+ nG = 0x1 << 11
+ shared = 0x3 << 8
+)
+
+const (
+ mtNormal = 0x4 << 2
+)
+
+const (
+ executeDisable = xn
+ optionMask = 0xfff | 0xfff<<48
+ protDefault = accessed | shared | mtNormal
+)
+
+// MapOpts are arm64 options.
+type MapOpts struct {
+ // AccessType defines permissions.
+ AccessType usermem.AccessType
+
+ // Global indicates the page is globally accessible.
+ Global bool
+
+ // User indicates the page is a user page.
+ User bool
+}
+
+// PTE is a page table entry.
+type PTE uintptr
+
+// Clear clears this PTE, including sect page information.
+//
+//go:nosplit
+func (p *PTE) Clear() {
+ atomic.StoreUintptr((*uintptr)(p), 0)
+}
+
+// Valid returns true iff this entry is valid.
+//
+//go:nosplit
+func (p *PTE) Valid() bool {
+ return atomic.LoadUintptr((*uintptr)(p))&present != 0
+}
+
+// Opts returns the PTE options.
+//
+// These are all options except Valid and Sect.
+//
+//go:nosplit
+func (p *PTE) Opts() MapOpts {
+ v := atomic.LoadUintptr((*uintptr)(p))
+
+ return MapOpts{
+ AccessType: usermem.AccessType{
+ Read: true,
+ Write: v&readOnly == 0,
+ Execute: v&xn == 0,
+ },
+ Global: v&nG == 0,
+ User: v&user != 0,
+ }
+}
+
+// SetSect sets this page as a sect page.
+//
+// The page must not be valid or a panic will result.
+//
+//go:nosplit
+func (p *PTE) SetSect() {
+ if p.Valid() {
+ // This is not allowed.
+ panic("SetSect called on valid page!")
+ }
+ atomic.StoreUintptr((*uintptr)(p), typeSect)
+}
+
+// IsSect returns true iff this page is a sect page.
+//
+//go:nosplit
+func (p *PTE) IsSect() bool {
+ return atomic.LoadUintptr((*uintptr)(p))&pteTypeMask == typeSect
+}
+
+// Set sets this PTE value.
+//
+// This does not change the sect page property.
+//
+//go:nosplit
+func (p *PTE) Set(addr uintptr, opts MapOpts) {
+ if !opts.AccessType.Any() {
+ p.Clear()
+ return
+ }
+ v := (addr &^ optionMask) | protDefault | nG | readOnly
+
+ if p.IsSect() {
+ // Note that this is inherited from the previous instance. Set
+ // does not change the value of Sect. See above.
+ v |= typeSect
+ } else {
+ v |= typePage
+ }
+
+ if opts.Global {
+ v = v &^ nG
+ }
+
+ if opts.AccessType.Execute {
+ v = v &^ executeDisable
+ } else {
+ v |= executeDisable
+ }
+ if opts.AccessType.Write {
+ v = v &^ readOnly
+ }
+
+ if opts.User {
+ v |= user
+ } else {
+ v = v &^ user
+ }
+ atomic.StoreUintptr((*uintptr)(p), v)
+}
+
+// setPageTable sets this PTE value and forces the write bit and sect bit to
+// be cleared. This is used explicitly for breaking sect pages.
+//
+//go:nosplit
+func (p *PTE) setPageTable(pt *PageTables, ptes *PTEs) {
+ addr := pt.Allocator.PhysicalFor(ptes)
+ if addr&^optionMask != addr {
+ // This should never happen.
+ panic("unaligned physical address!")
+ }
+ v := addr | typeTable | protDefault
+ atomic.StoreUintptr((*uintptr)(p), v)
+}
+
+// Address extracts the address. This should only be used if Valid returns true.
+//
+//go:nosplit
+func (p *PTE) Address() uintptr {
+ return atomic.LoadUintptr((*uintptr)(p)) &^ optionMask
+}
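Note that in this encoding writability (readOnly cleared, backed by the DBM bit) and global visibility (nG cleared) are expressed as *absent* bits, which is easy to misread in Set and Opts above. A self-contained sketch of the encode/decode round trip, restating only the bit constants it needs:

package main

import "fmt"

const (
	user     = uintptr(1) << 6  // AP[1]: user-accessible
	readOnly = uintptr(1) << 7  // AP[2]: clear means writable
	nG       = uintptr(1) << 11 // not-global: clear means global
	xn       = uintptr(1) << 54 // execute-never: clear means executable
)

type opts struct{ read, write, execute, global, user bool }

// encode starts from the most restrictive bits and clears them per
// option, mirroring PTE.Set above.
func encode(addr uintptr, o opts) uintptr {
	v := addr | nG | readOnly | xn
	if o.write {
		v &^= readOnly
	}
	if o.execute {
		v &^= xn
	}
	if o.global {
		v &^= nG
	}
	if o.user {
		v |= user
	}
	return v
}

// decode inverts encode, mirroring PTE.Opts above.
func decode(v uintptr) opts {
	return opts{
		read:    true,
		write:   v&readOnly == 0,
		execute: v&xn == 0,
		global:  v&nG == 0,
		user:    v&user != 0,
	}
}

func main() {
	v := encode(0x400000, opts{read: true, write: true, user: true})
	fmt.Printf("pte=%#x -> %+v\n", v, decode(v))
}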
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
index 7aa6c524e..0c153cf8c 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
@@ -41,5 +41,14 @@ const (
entriesPerPage = 512
)
+// Init initializes a set of PageTables.
+//
+//go:nosplit
+func (p *PageTables) Init(allocator Allocator) {
+ p.Allocator = allocator
+ p.root = p.Allocator.NewPTEs()
+ p.rootPhysical = p.Allocator.PhysicalFor(p.root)
+}
+
// PTEs is a collection of entries.
type PTEs [entriesPerPage]PTE
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
new file mode 100644
index 000000000..1a49f12a2
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
@@ -0,0 +1,57 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pagetables
+
+// Address constraints.
+//
+// The lowerTop and upperBottom currently apply to four-level pagetables;
+// additional refactoring would be necessary to support five-level pagetables.
+const (
+ lowerTop = 0x0000ffffffffffff
+ upperBottom = 0xffff000000000000
+ pteShift = 12
+ pmdShift = 21
+ pudShift = 30
+ pgdShift = 39
+
+ pteMask = 0x1ff << pteShift
+ pmdMask = 0x1ff << pmdShift
+ pudMask = 0x1ff << pudShift
+ pgdMask = 0x1ff << pgdShift
+
+ pteSize = 1 << pteShift
+ pmdSize = 1 << pmdShift
+ pudSize = 1 << pudShift
+ pgdSize = 1 << pgdShift
+
+ ttbrASIDOffset = 55
+ ttbrASIDMask = 0xff
+
+ entriesPerPage = 512
+)
+
+// Init initializes a set of PageTables.
+//
+//go:nosplit
+func (p *PageTables) Init(allocator Allocator) {
+ p.Allocator = allocator
+ p.root = p.Allocator.NewPTEs()
+ p.rootPhysical = p.Allocator.PhysicalFor(p.root)
+ p.archPageTables.root = p.Allocator.NewPTEs()
+ p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root)
+}
+
+// PTEs is a collection of entries.
+type PTEs [entriesPerPage]PTE
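Each level consumes a 9-bit index, so four levels plus the 12-bit page offset cover the 48-bit canonical space bounded by lowerTop and upperBottom. A runnable sketch (constants restated locally) of how the walker derives its per-level indices from a virtual address:

package main

import "fmt"

const (
	pteShift = 12
	pmdShift = 21
	pudShift = 30
	pgdShift = 39

	pteMask = uintptr(0x1ff) << pteShift
	pmdMask = uintptr(0x1ff) << pmdShift
	pudMask = uintptr(0x1ff) << pudShift
	pgdMask = uintptr(0x1ff) << pgdShift
)

func main() {
	// An arbitrary user address; each level extracts its own 9 bits.
	addr := uintptr(0x0000ff0000403000)
	fmt.Printf("pgd=%d pud=%d pmd=%d pte=%d\n",
		(addr&pgdMask)>>pgdShift,
		(addr&pudMask)>>pudShift,
		(addr&pmdMask)>>pmdShift,
		(addr&pteMask)>>pteShift)
}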
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go
new file mode 100644
index 000000000..254116233
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go
@@ -0,0 +1,80 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package pagetables
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+func Test2MAnd4K(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a small page and a huge page.
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42)
+ pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*47)
+
+ pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: false}, pteSize*42)
+ pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: false}, pmdSize*47)
+
+ checkMappings(t, pt, []mapping{
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}},
+ {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: true}},
+ {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: false}},
+ {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: false}},
+ })
+}
+
+func Test1GAnd4K(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a small page and a super page.
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42)
+ pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*47)
+
+ checkMappings(t, pt, []mapping{
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}},
+ {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read, User: true}},
+ })
+}
+
+func TestSplit1GPage(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a super page and knock out the middle.
+ pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*42)
+ pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize))
+
+ checkMappings(t, pt, []mapping{
+ {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read, User: true}},
+ {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}},
+ })
+}
+
+func TestSplit2MPage(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a huge page and knock out the middle.
+ pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*42)
+ pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize))
+
+ checkMappings(t, pt, []mapping{
+ {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read, User: true}},
+ {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}},
+ })
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
new file mode 100644
index 000000000..c261d393a
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
@@ -0,0 +1,314 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package pagetables
+
+// Visitor is a generic type.
+type Visitor interface {
+ // visit is called on each PTE.
+ visit(start uintptr, pte *PTE, align uintptr)
+
+ // requiresAlloc indicates that new entries should be allocated within
+ // the walked range.
+ requiresAlloc() bool
+
+ // requiresSplit indicates that entries in the given range should be
+ // split if they are huge or jumbo pages.
+ requiresSplit() bool
+}
+
+// Walker walks page tables.
+type Walker struct {
+ // pageTables are the tables to walk.
+ pageTables *PageTables
+
+ // Visitor is the set of arguments.
+ visitor Visitor
+}
+
+// iterateRange iterates over all appropriate levels of page tables for the given range.
+//
+// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The
+// exception is sect pages. If a valid sect page (huge or jumbo) cannot be
+// installed, then the walk will continue to individual entries.
+//
+// This algorithm will attempt to maximize the use of sect pages whenever
+// possible. Whether a sect page is provided will be clear through the range
+// provided in the callback.
+//
+// Note that if requiresAlloc is true, then no gaps will be present. However,
+// if alloc is not set, then the iteration will likely be full of gaps.
+//
+// Note that this function should generally be avoided in favor of Map, Unmap,
+// etc. when not necessary.
+//
+// Precondition: start must be page-aligned.
+//
+// Precondition: start must be less than end.
+//
+// Precondition: If requiresAlloc is true, then start and end should not span
+// non-canonical ranges. If they do, a panic will result.
+//
+//go:nosplit
+func (w *Walker) iterateRange(start, end uintptr) {
+ if start%pteSize != 0 {
+ panic("unaligned start")
+ }
+ if end < start {
+ panic("start > end")
+ }
+ if start < lowerTop {
+ if end <= lowerTop {
+ w.iterateRangeCanonical(start, end)
+ } else if end > lowerTop && end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else if start < upperBottom {
+ if end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else {
+ w.iterateRangeCanonical(start, end)
+ }
+}
+
+// next returns the next address quantized by the given size.
+//
+//go:nosplit
+func next(start uintptr, size uintptr) uintptr {
+ start &= ^(size - 1)
+ start += size
+ return start
+}
+
+// iterateRangeCanonical walks a canonical range.
+//
+//go:nosplit
+func (w *Walker) iterateRangeCanonical(start, end uintptr) {
+ pgdEntryIndex := w.pageTables.root
+ if start >= upperBottom {
+ pgdEntryIndex = w.pageTables.archPageTables.root
+ }
+
+ for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ {
+ var (
+ pgdEntry = &pgdEntryIndex[pgdIndex]
+ pudEntries *PTEs
+ )
+ if !pgdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+ // Skip over this entry.
+ start = next(start, pgdSize)
+ continue
+ }
+
+ // Allocate a new pgd.
+ pudEntries = w.pageTables.Allocator.NewPTEs()
+ pgdEntry.setPageTable(w.pageTables, pudEntries)
+ } else {
+ pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address())
+ }
+
+ // Map the next level.
+ clearPUDEntries := uint16(0)
+
+ for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ {
+ var (
+ pudEntry = &pudEntries[pudIndex]
+ pmdEntries *PTEs
+ )
+ if !pudEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+ // Skip over this entry.
+ clearPUDEntries++
+ start = next(start, pudSize)
+ continue
+ }
+
+ // This level has 1-GB sect pages. Is this
+ // entire region at least as large as a single
+ // PUD entry? If so, we can skip allocating a
+ // new page for the pmd.
+ if start&(pudSize-1) == 0 && end-start >= pudSize {
+ pudEntry.SetSect()
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+ if pudEntry.Valid() {
+ start = next(start, pudSize)
+ continue
+ }
+ }
+
+ // Allocate a new pud.
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+
+ } else if pudEntry.IsSect() {
+ // Does this page need to be split?
+ if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < next(start, pudSize)) {
+ // Install the relevant entries.
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pmdEntries[index].SetSect()
+ pmdEntries[index].Set(
+ pudEntry.Address()+(pmdSize*uintptr(index)),
+ pudEntry.Opts())
+ }
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+ } else {
+ // A sect page to be checked directly.
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+
+ // Might have been cleared.
+ if !pudEntry.Valid() {
+ clearPUDEntries++
+ }
+
+ // Note that the sect page was changed.
+ start = next(start, pudSize)
+ continue
+ }
+
+ } else {
+ pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address())
+ }
+
+ // Map the next level, since this is valid.
+ clearPMDEntries := uint16(0)
+
+ for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ {
+ var (
+ pmdEntry = &pmdEntries[pmdIndex]
+ pteEntries *PTEs
+ )
+ if !pmdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+ // Skip over this entry.
+ clearPMDEntries++
+ start = next(start, pmdSize)
+ continue
+ }
+
+ // This level has 2-MB huge pages. Is this
+ // entire region contained in a single PMD
+ // entry? If so, as above, we can skip
+ // allocating a new page.
+ if start&(pmdSize-1) == 0 && end-start >= pmdSize {
+ pmdEntry.SetSect()
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+ if pmdEntry.Valid() {
+ start = next(start, pmdSize)
+ continue
+ }
+ }
+
+ // Allocate a new pmd.
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+
+ } else if pmdEntry.IsSect() {
+ // Does this page need to be split?
+ if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < next(start, pmdSize)) {
+ // Install the relevant entries.
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pteEntries[index].Set(
+ pmdEntry.Address()+(pteSize*uintptr(index)),
+ pmdEntry.Opts())
+ }
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+ } else {
+ // A huge page to be checked directly.
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+
+ // Might have been cleared.
+ if !pmdEntry.Valid() {
+ clearPMDEntries++
+ }
+
+ // Note that the huge page was changed.
+ start = next(start, pmdSize)
+ continue
+ }
+
+ } else {
+ pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address())
+ }
+
+ // Map the next level, since this is valid.
+ clearPTEEntries := uint16(0)
+
+ for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ {
+ var (
+ pteEntry = &pteEntries[pteIndex]
+ )
+ if !pteEntry.Valid() && !w.visitor.requiresAlloc() {
+ clearPTEEntries++
+ start += pteSize
+ continue
+ }
+
+ // At this point, we are guaranteed that start%pteSize == 0.
+ w.visitor.visit(uintptr(start), pteEntry, pteSize-1)
+ if !pteEntry.Valid() {
+ if w.visitor.requiresAlloc() {
+ panic("PTE not set after iteration with requiresAlloc!")
+ }
+ clearPTEEntries++
+ }
+
+ // Note that the pte was changed.
+ start += pteSize
+ continue
+ }
+
+ // Check if we no longer need this page.
+ if clearPTEEntries == entriesPerPage {
+ pmdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pteEntries)
+ clearPMDEntries++
+ }
+ }
+
+ // Check if we no longer need this page.
+ if clearPMDEntries == entriesPerPage {
+ pudEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pmdEntries)
+ clearPUDEntries++
+ }
+ }
+
+ // Check if we no longer need this page.
+ if clearPUDEntries == entriesPerPage {
+ pgdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pudEntries)
+ }
+ }
+}
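Since the visitor methods are unexported, implementations necessarily live in this package (walker_map.go, walker_empty.go, walker_lookup.go, and so on). The following self-contained sketch restates the contract with a stand-in PTE type: a visitor that merely counts valid entries, and so neither allocates intermediate tables nor splits sect pages:

package main

import "fmt"

// PTE stands in for the real pagetables.PTE; only Valid is needed here.
type PTE uintptr

func (p PTE) Valid() bool { return p != 0 }

// visitor mirrors the unexported pagetables.Visitor contract.
type visitor interface {
	visit(start uintptr, pte *PTE, align uintptr)
	requiresAlloc() bool
	requiresSplit() bool
}

// countVisitor counts valid entries. Because requiresAlloc and
// requiresSplit both return false, a walk with it skips gaps and
// leaves sect pages intact.
type countVisitor struct{ n int }

var _ visitor = (*countVisitor)(nil)

func (v *countVisitor) visit(start uintptr, pte *PTE, align uintptr) {
	if pte.Valid() {
		v.n++
	}
}
func (*countVisitor) requiresAlloc() bool { return false }
func (*countVisitor) requiresSplit() bool { return false }

func main() {
	ptes := []PTE{0, 1, 1, 0}
	v := &countVisitor{}
	for i := range ptes {
		v.visit(uintptr(i)<<12, &ptes[i], 1<<12-1)
	}
	fmt.Println("valid entries:", v.n) // valid entries: 2
}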
diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sentry/sighandling/sighandling.go
index 2f65db70b..ba1f9043d 100644
--- a/pkg/sentry/sighandling/sighandling.go
+++ b/pkg/sentry/sighandling/sighandling.go
@@ -16,7 +16,6 @@
package sighandling
import (
- "fmt"
"os"
"os/signal"
"reflect"
@@ -31,37 +30,25 @@ const numSignals = 32
// handleSignals listens for incoming signals and calls the given handler
// function.
//
-// It starts when the start channel is closed, stops when the stop channel
-// is closed, and closes done once it will no longer deliver signals to k.
-func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start, stop, done chan struct{}) {
+// It stops when the stop channel is closed. The done channel is closed once it
+// will no longer deliver signals to k.
+func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), stop, done chan struct{}) {
// Build a select case.
- sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(start)}}
+ sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)}}
for _, sigchan := range sigchans {
sc = append(sc, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sigchan)})
}
- started := false
for {
// Wait for a notification.
index, _, ok := reflect.Select(sc)
- // Was it the start / stop channel?
+ // Was it the stop channel?
if index == 0 {
if !ok {
- if !started {
- // start channel; start forwarding and
- // swap this case for the stop channel
- // to select stop requests.
- started = true
- sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)}
- } else {
- // stop channel; stop forwarding and
- // clear this case so it is never
- // selected again.
- started = false
- close(done)
- sc[0].Chan = reflect.Value{}
- }
+ // Stop forwarding and notify that it's done.
+ close(done)
+ return
}
continue
}
@@ -73,44 +60,17 @@ func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start,
// Otherwise, it was a signal on channel N. Index 0 represents the stop
// channel, so index N represents the channel for signal N.
- signal := linux.Signal(index)
-
- if !started {
- // Kernel cannot receive signals, either because it is
- // not ready yet or is shutting down.
- //
- // Kill ourselves if this signal would have killed the
- // process before PrepareForwarding was called. i.e., all
- // _SigKill signals; see Go
- // src/runtime/sigtab_linux_generic.go.
- //
- // Otherwise ignore the signal.
- //
- // TODO(b/114489875): Drop in Go 1.12, which uses tgkill
- // in runtime.raise.
- switch signal {
- case linux.SIGHUP, linux.SIGINT, linux.SIGTERM:
- dieFromSignal(signal)
- panic(fmt.Sprintf("Failed to die from signal %d", signal))
- default:
- continue
- }
- }
-
- // Pass the signal to the handler.
- handler(signal)
+ handler(linux.Signal(index))
}
}
-// PrepareHandler ensures that synchronous signals are passed to the given
-// handler function and returns a callback that starts signal delivery, which
-// itself returns a callback that stops signal handling.
+// StartSignalForwarding ensures that synchronous signals are passed to the
+// given handler function and returns a callback that stops signal delivery.
//
// Note that this function permanently takes over signal handling. After the
// stop callback, signals revert to the default Go runtime behavior, which
// cannot be overridden with external calls to signal.Notify.
-func PrepareHandler(handler func(linux.Signal)) func() func() {
- start := make(chan struct{})
+func StartSignalForwarding(handler func(linux.Signal)) func() {
stop := make(chan struct{})
done := make(chan struct{})
@@ -128,13 +88,10 @@ func PrepareHandler(handler func(linux.Signal)) func() func() {
signal.Notify(sigchan, syscall.Signal(sig))
}
// Start up our listener.
- go handleSignals(sigchans, handler, start, stop, done) // S/R-SAFE: synchronized by Kernel.extMu.
+ go handleSignals(sigchans, handler, stop, done) // S/R-SAFE: synchronized by Kernel.extMu.
- return func() func() {
- close(start)
- return func() {
- close(stop)
- <-done
- }
+ return func() {
+ close(stop)
+ <-done
}
}
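With the start channel gone, forwarding begins as soon as the goroutine launches, and the single returned callback both requests the stop and waits for done, so no handler call can race the caller once it returns. A runnable sketch of that stop/done shape, independent of the sentry types (SIGUSR1 and the print-only handler are illustrative):

package main

import (
	"fmt"
	"os"
	"os/signal"
	"syscall"
	"time"
)

// startForwarding mirrors the shape of StartSignalForwarding: it
// forwards immediately and returns a callback that closes stop, then
// blocks on done.
func startForwarding(handler func(os.Signal)) func() {
	stop := make(chan struct{})
	done := make(chan struct{})
	sigchan := make(chan os.Signal, 1)
	signal.Notify(sigchan, syscall.SIGUSR1)

	go func() {
		defer close(done)
		for {
			select {
			case <-stop:
				signal.Stop(sigchan)
				return
			case sig := <-sigchan:
				handler(sig)
			}
		}
	}()

	return func() {
		close(stop)
		<-done
	}
}

func main() {
	stop := startForwarding(func(sig os.Signal) { fmt.Println("got", sig) })
	syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
	time.Sleep(100 * time.Millisecond) // give the goroutine time to run
	stop()
}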
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go
index c303435d5..1ebe22d34 100644
--- a/pkg/sentry/sighandling/sighandling_unsafe.go
+++ b/pkg/sentry/sighandling/sighandling_unsafe.go
@@ -15,8 +15,6 @@
package sighandling
import (
- "fmt"
- "runtime"
"syscall"
"unsafe"
@@ -48,27 +46,3 @@ func IgnoreChildStop() error {
return nil
}
-
-// dieFromSignal kills the current process with sig.
-//
-// Preconditions: The default action of sig is termination.
-func dieFromSignal(sig linux.Signal) {
- runtime.LockOSThread()
- defer runtime.UnlockOSThread()
-
- sa := sigaction{handler: linux.SIG_DFL}
- if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 {
- panic(fmt.Sprintf("rt_sigaction failed: %v", e))
- }
-
- set := linux.MakeSignalSet(sig)
- if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_UNBLOCK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0); e != 0 {
- panic(fmt.Sprintf("rt_sigprocmask failed: %v", e))
- }
-
- if err := syscall.Tgkill(syscall.Getpid(), syscall.Gettid(), syscall.Signal(sig)); err != nil {
- panic(fmt.Sprintf("tgkill failed: %v", err))
- }
-
- panic("failed to die")
-}
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index 4a6e83a8b..357517ed4 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -17,6 +17,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/socket",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usermem",
"//pkg/syserror",
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 4e95101b7..af1a4e95f 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
@@ -194,15 +195,15 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([
// the available space, we must align down.
//
// align must be >= 4 and each data int32 is 4 bytes. The length of the
- // header is already aligned, so if we align to the with of the data there
+ // header is already aligned, so if we align to the width of the data there
// are two cases:
// 1. The aligned length is less than the length of the header. The
// unaligned length was also less than the length of the header, so we
// can't write anything.
// 2. The aligned length is greater than or equal to the length of the
- // header. We can write the header plus zero or more datas. We can't write
- // a partial int32, so the length of the message will be
- // min(aligned length, header + datas).
+ // header. We can write the header plus zero or more bytes of data. We can't
+ // write a partial int32, so the length of the message will be
+ // min(aligned length, header + data).
if space < linux.SizeOfControlMessageHeader {
flags |= linux.MSG_CTRUNC
return buf, flags
@@ -239,12 +240,12 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf
buf = binary.Marshal(buf, usermem.ByteOrder, data)
- // Check if we went over.
+ // If the control message data brought us over capacity, omit it.
if cap(buf) != cap(ob) {
return hdrBuf
}
- // Fix up length.
+ // Update control message length to include data.
putUint64(ob, uint64(len(buf)-len(ob)))
return alignSlice(buf, align)
@@ -320,35 +321,109 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte {
buf,
linux.SOL_TCP,
linux.TCP_INQ,
- 4,
+ t.Arch().Width(),
inq,
)
}
+// PackTOS packs an IP_TOS socket control message.
+func PackTOS(t *kernel.Task, tos int8, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_TOS,
+ t.Arch().Width(),
+ tos,
+ )
+}
+
+// PackTClass packs an IPV6_TCLASS socket control message.
+func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IPV6,
+ linux.IPV6_TCLASS,
+ t.Arch().Width(),
+ tClass,
+ )
+}
+
+// PackControlMessages packs control messages into the given buffer.
+//
+// We skip control messages specific to Unix domain sockets.
+//
+// Note that some control messages may be truncated if they do not fit within
+// the capacity of buf.
+func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte {
+ if cmsgs.IP.HasTimestamp {
+ buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf)
+ }
+
+ if cmsgs.IP.HasInq {
+ // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP.
+ buf = PackInq(t, cmsgs.IP.Inq, buf)
+ }
+
+ if cmsgs.IP.HasTOS {
+ buf = PackTOS(t, cmsgs.IP.TOS, buf)
+ }
+
+ if cmsgs.IP.HasTClass {
+ buf = PackTClass(t, cmsgs.IP.TClass, buf)
+ }
+
+ return buf
+}
+
+// cmsgSpace is equivalent to CMSG_SPACE in Linux.
+func cmsgSpace(t *kernel.Task, dataLen int) int {
+ return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width())
+}
+
+// CmsgsSpace returns the number of bytes needed to fit the control messages
+// represented in cmsgs.
+func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
+ space := 0
+
+ if cmsgs.IP.HasTimestamp {
+ space += cmsgSpace(t, linux.SizeOfTimeval)
+ }
+
+ if cmsgs.IP.HasInq {
+ space += cmsgSpace(t, linux.SizeOfControlMessageInq)
+ }
+
+ if cmsgs.IP.HasTOS {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTOS)
+ }
+
+ if cmsgs.IP.HasTClass {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
+ }
+
+ return space
+}
+
// Parse parses a raw socket control message into portable objects.
-func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) {
+func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
var (
- fds linux.ControlMessageRights
- haveCreds bool
- creds linux.ControlMessageCredentials
+ cmsgs socket.ControlMessages
+ fds linux.ControlMessageRights
)
for i := 0; i < len(buf); {
if i+linux.SizeOfControlMessageHeader > len(buf) {
- return transport.ControlMessages{}, syserror.EINVAL
+ return cmsgs, syserror.EINVAL
}
var h linux.ControlMessageHeader
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h)
if h.Length < uint64(linux.SizeOfControlMessageHeader) {
- return transport.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, syserror.EINVAL
}
if h.Length > uint64(len(buf)-i) {
- return transport.ControlMessages{}, syserror.EINVAL
- }
- if h.Level != linux.SOL_SOCKET {
- return transport.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, syserror.EINVAL
}
i += linux.SizeOfControlMessageHeader
@@ -358,59 +433,79 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.
// sizeof(long) in CMSG_ALIGN.
width := t.Arch().Width()
- switch h.Type {
- case linux.SCM_RIGHTS:
- rightsSize := AlignDown(length, linux.SizeOfControlMessageRight)
- numRights := rightsSize / linux.SizeOfControlMessageRight
-
- if len(fds)+numRights > linux.SCM_MAX_FD {
- return transport.ControlMessages{}, syserror.EINVAL
+ switch h.Level {
+ case linux.SOL_SOCKET:
+ switch h.Type {
+ case linux.SCM_RIGHTS:
+ rightsSize := AlignDown(length, linux.SizeOfControlMessageRight)
+ numRights := rightsSize / linux.SizeOfControlMessageRight
+
+ if len(fds)+numRights > linux.SCM_MAX_FD {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
+ fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
+ }
+
+ i += AlignUp(length, width)
+
+ case linux.SCM_CREDENTIALS:
+ if length < linux.SizeOfControlMessageCredentials {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ var creds linux.ControlMessageCredentials
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
+ scmCreds, err := NewSCMCredentials(t, creds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Credentials = scmCreds
+ i += AlignUp(length, width)
+
+ default:
+ // Unknown message type.
+ return socket.ControlMessages{}, syserror.EINVAL
}
-
- for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
- fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
+ case linux.SOL_IP:
+ switch h.Type {
+ case linux.IP_TOS:
+ cmsgs.IP.HasTOS = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
+ i += AlignUp(length, width)
+
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
}
-
- i += AlignUp(length, width)
-
- case linux.SCM_CREDENTIALS:
- if length < linux.SizeOfControlMessageCredentials {
- return transport.ControlMessages{}, syserror.EINVAL
+ case linux.SOL_IPV6:
+ switch h.Type {
+ case linux.IPV6_TCLASS:
+ cmsgs.IP.HasTClass = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
+ i += AlignUp(length, width)
+
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
}
-
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
- haveCreds = true
- i += AlignUp(length, width)
-
default:
- // Unknown message type.
- return transport.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, syserror.EINVAL
}
}
- var credentials SCMCredentials
- if haveCreds {
- var err error
- if credentials, err = NewSCMCredentials(t, creds); err != nil {
- return transport.ControlMessages{}, err
- }
- } else {
- credentials = makeCreds(t, socketOrEndpoint)
+ if cmsgs.Unix.Credentials == nil {
+ cmsgs.Unix.Credentials = makeCreds(t, socketOrEndpoint)
}
- var rights SCMRights
if len(fds) > 0 {
- var err error
- if rights, err = NewSCMRights(t, fds); err != nil {
- return transport.ControlMessages{}, err
+ rights, err := NewSCMRights(t, fds)
+ if err != nil {
+ return socket.ControlMessages{}, err
}
+ cmsgs.Unix.Rights = rights
}
- if credentials == nil && rights == nil {
- return transport.ControlMessages{}, nil
- }
-
- return transport.ControlMessages{Credentials: credentials, Rights: rights}, nil
+ return cmsgs, nil
}
func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials {
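cmsgSpace and CmsgsSpace mirror Linux's CMSG_SPACE: each message costs its header plus the payload rounded up to the task's word width. A runnable sketch with the usual 64-bit Linux sizes restated locally (16-byte cmsghdr, 8-byte alignment):

package main

import "fmt"

const (
	sizeOfCmsgHeader = 16 // struct cmsghdr on 64-bit Linux
	width            = 8  // t.Arch().Width() on a 64-bit arch
)

// alignUp rounds length up to the next multiple of align (a power of 2).
func alignUp(length, align int) int {
	return (length + align - 1) &^ (align - 1)
}

// cmsgSpace is equivalent to CMSG_SPACE for one message.
func cmsgSpace(dataLen int) int {
	return sizeOfCmsgHeader + alignUp(dataLen, width)
}

func main() {
	// A timeval timestamp (16 bytes) plus a one-byte IP_TOS value.
	space := cmsgSpace(16) + cmsgSpace(1)
	fmt.Println("buffer needed:", space) // buffer needed: 56
}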
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index 4d174dda4..4c44c7c0f 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -29,9 +29,12 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/safemem",
"//pkg/sentry/socket",
+ "//pkg/sentry/socket/control",
"//pkg/sentry/usermem",
"//pkg/syserr",
"//pkg/syserror",
+ "//pkg/tcpip/stack",
"//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 92beb1bcf..c957b0f1d 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -18,6 +18,7 @@ import (
"fmt"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/fdnotifier"
@@ -29,6 +30,7 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
@@ -41,6 +43,10 @@ const (
// sizeofSockaddr is the size in bytes of the largest sockaddr type
// supported by this package.
sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in)
+
+ // maxControlLen is the maximum size of a control message buffer used in a
+ // recvmsg or sendmsg syscall.
+ maxControlLen = 1024
)
// socketOperations implements fs.FileOperations and socket.Socket for a socket
@@ -281,26 +287,32 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
// Whitelist options and constrain option length.
var optlen int
switch level {
- case syscall.SOL_IPV6:
+ case linux.SOL_IP:
+ switch name {
+ case linux.IP_TOS, linux.IP_RECVTOS:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_IPV6:
switch name {
- case syscall.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
optlen = sizeofInt32
}
- case syscall.SOL_SOCKET:
+ case linux.SOL_SOCKET:
switch name {
- case syscall.SO_ERROR, syscall.SO_KEEPALIVE, syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR:
+ case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
optlen = sizeofInt32
- case syscall.SO_LINGER:
+ case linux.SO_LINGER:
optlen = syscall.SizeofLinger
}
- case syscall.SOL_TCP:
+ case linux.SOL_TCP:
switch name {
- case syscall.TCP_NODELAY:
+ case linux.TCP_NODELAY:
optlen = sizeofInt32
- case syscall.TCP_INFO:
+ case linux.TCP_INFO:
optlen = int(linux.SizeOfTCPInfo)
}
}
+
if optlen == 0 {
return nil, syserr.ErrProtocolNotAvailable // ENOPROTOOPT
}
@@ -320,19 +332,24 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
// Whitelist options and constrain option length.
var optlen int
switch level {
- case syscall.SOL_IPV6:
+ case linux.SOL_IP:
+ switch name {
+ case linux.IP_TOS, linux.IP_RECVTOS:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_IPV6:
switch name {
- case syscall.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
optlen = sizeofInt32
}
- case syscall.SOL_SOCKET:
+ case linux.SOL_SOCKET:
switch name {
- case syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR:
+ case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
optlen = sizeofInt32
}
- case syscall.SOL_TCP:
+ case linux.SOL_TCP:
switch name {
- case syscall.TCP_NODELAY:
+ case linux.TCP_NODELAY:
optlen = sizeofInt32
}
}
@@ -354,11 +371,11 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
// Whitelist flags.
//
// FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
- // messages that netstack/tcpip/transport/unix doesn't understand. Kill the
+ // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the
// Socket interface's dependence on netstack.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
@@ -370,6 +387,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
senderAddrBuf = make([]byte, sizeofSockaddr)
}
+ var controlBuf []byte
var msgFlags int
recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
@@ -384,11 +402,6 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
// We always do a non-blocking recv*().
sysflags := flags | syscall.MSG_DONTWAIT
- if dsts.NumBlocks() == 1 {
- // Skip allocating []syscall.Iovec.
- return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf)
- }
-
iovs := iovecsFromBlockSeq(dsts)
msg := syscall.Msghdr{
Iov: &iovs[0],
@@ -398,12 +411,21 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
msg.Name = &senderAddrBuf[0]
msg.Namelen = uint32(len(senderAddrBuf))
}
+ if controlLen > 0 {
+ if controlLen > maxControlLen {
+ controlLen = maxControlLen
+ }
+ controlBuf = make([]byte, controlLen)
+ msg.Control = &controlBuf[0]
+ msg.Controllen = controlLen
+ }
n, err := recvmsg(s.fd, &msg, sysflags)
if err != nil {
return 0, err
}
senderAddrBuf = senderAddrBuf[:msg.Namelen]
msgFlags = int(msg.Flags)
+ controlLen = uint64(msg.Controllen)
return n, nil
})
@@ -429,14 +451,38 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
n, err = dst.CopyOutFrom(t, recvmsgToBlocks)
}
}
-
- // We don't allow control messages.
- msgFlags &^= linux.MSG_CTRUNC
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
if senderRequested {
senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
}
- return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err)
+
+ unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen])
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
+ controlMessages := socket.ControlMessages{}
+ for _, unixCmsg := range unixControlMessages {
+ switch unixCmsg.Header.Level {
+ case syscall.SOL_IP:
+ switch unixCmsg.Header.Type {
+ case syscall.IP_TOS:
+ controlMessages.IP.HasTOS = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
+ }
+ case syscall.SOL_IPV6:
+ switch unixCmsg.Header.Type {
+ case syscall.IPV6_TCLASS:
+ controlMessages.IP.HasTClass = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass)
+ }
+ }
+ }
+
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil
}
// SendMsg implements socket.Socket.SendMsg.
@@ -446,6 +492,14 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return 0, syserr.ErrInvalidArgument
}
+ space := uint64(control.CmsgsSpace(t, controlMessages))
+ if space > maxControlLen {
+ space = maxControlLen
+ }
+ controlBuf := make([]byte, 0, space)
+ // PackControlMessages will append up to space bytes to controlBuf.
+ controlBuf = control.PackControlMessages(t, controlMessages, controlBuf)
+
sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
// Refuse to do anything if any part of src.Addrs was unusable.
if uint64(src.NumBytes()) != srcs.NumBytes() {
@@ -458,7 +512,7 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
// We always do a non-blocking send*().
sysflags := flags | syscall.MSG_DONTWAIT
- if srcs.NumBlocks() == 1 {
+ if srcs.NumBlocks() == 1 && len(controlBuf) == 0 {
// Skip allocating []syscall.Iovec.
src := srcs.Head()
n, _, errno := syscall.Syscall6(syscall.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to)))
@@ -477,6 +531,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
msg.Name = &to[0]
msg.Namelen = uint32(len(to))
}
+ if len(controlBuf) != 0 {
+ msg.Control = &controlBuf[0]
+ msg.Controllen = uint64(len(controlBuf))
+ }
return sendmsg(s.fd, &msg, sysflags)
})
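On the receive path above, the raw buffer filled by the host recvmsg is decoded with golang.org/x/sys/unix.ParseSocketControlMessage. A runnable, Linux-only sketch that builds a one-message buffer by hand and parses it back; the 0x10 TOS value is arbitrary:

package main

import (
	"fmt"
	"unsafe"

	"golang.org/x/sys/unix"
)

func main() {
	// Build one control message: level=SOL_IP, type=IP_TOS, data=0x10.
	buf := make([]byte, unix.CmsgSpace(1))
	h := (*unix.Cmsghdr)(unsafe.Pointer(&buf[0]))
	h.Level = unix.SOL_IP
	h.Type = unix.IP_TOS
	h.SetLen(unix.CmsgLen(1))
	buf[unix.CmsgLen(0)] = 0x10

	cmsgs, err := unix.ParseSocketControlMessage(buf)
	if err != nil {
		panic(err)
	}
	for _, c := range cmsgs {
		fmt.Printf("level=%d type=%d data=%v\n", c.Header.Level, c.Header.Type, c.Data)
	}
}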
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 3a4fdec47..e67b46c9e 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -16,8 +16,11 @@ package hostinet
import (
"fmt"
+ "io"
"io/ioutil"
"os"
+ "reflect"
+ "strconv"
"strings"
"syscall"
@@ -26,7 +29,9 @@ import (
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
var defaultRecvBufSize = inet.TCPBufferSize{
@@ -51,6 +56,8 @@ type Stack struct {
tcpRecvBufSize inet.TCPBufferSize
tcpSendBufSize inet.TCPBufferSize
tcpSACKEnabled bool
+ netDevFile *os.File
+ netSNMPFile *os.File
}
// NewStack returns an empty Stack containing no configuration.
@@ -98,6 +105,18 @@ func (s *Stack) Configure() error {
log.Warningf("Failed to read if TCP SACK if enabled, setting to true")
}
+ if f, err := os.Open("/proc/net/dev"); err != nil {
+ log.Warningf("Failed to open /proc/net/dev: %v", err)
+ } else {
+ s.netDevFile = f
+ }
+
+ if f, err := os.Open("/proc/net/snmp"); err != nil {
+ log.Warningf("Failed to open /proc/net/snmp: %v", err)
+ } else {
+ s.netSNMPFile = f
+ }
+
return nil
}
@@ -326,9 +345,95 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
return syserror.EACCES
}
+// getLine reads the line with the given prefix from a proc file. The last
+// argument, withHeader, indicates whether the first matching line is a
+// column header that should be skipped.
+func getLine(f *os.File, prefix string, withHeader bool) string {
+ data := make([]byte, 4096)
+
+ if _, err := f.Seek(0, 0); err != nil {
+ return ""
+ }
+
+ if _, err := io.ReadFull(f, data); err != io.ErrUnexpectedEOF {
+ return ""
+ }
+
+ prefix = prefix + ":"
+ lines := strings.Split(string(data), "\n")
+ for _, l := range lines {
+ l = strings.TrimSpace(l)
+ if strings.HasPrefix(l, prefix) {
+ if withHeader {
+ withHeader = false
+ continue
+ }
+ return l
+ }
+ }
+ return ""
+}
+
+func toSlice(i interface{}) []uint64 {
+ v := reflect.Indirect(reflect.ValueOf(i))
+ return v.Slice(0, v.Len()).Interface().([]uint64)
+}
+
// Statistics implements inet.Stack.Statistics.
func (s *Stack) Statistics(stat interface{}, arg string) error {
- return syserror.EOPNOTSUPP
+ var (
+ snmpTCP bool
+ rawLine string
+ sliceStat []uint64
+ )
+
+ switch stat.(type) {
+ case *inet.StatDev:
+ if s.netDevFile == nil {
+ return fmt.Errorf("/proc/net/dev is not opened for hostinet")
+ }
+ rawLine = getLine(s.netDevFile, arg, false /* withHeader */)
+ case *inet.StatSNMPIP, *inet.StatSNMPICMP, *inet.StatSNMPICMPMSG, *inet.StatSNMPTCP, *inet.StatSNMPUDP, *inet.StatSNMPUDPLite:
+ if s.netSNMPFile == nil {
+ return fmt.Errorf("/proc/net/snmp is not opened for hostinet")
+ }
+ rawLine = getLine(s.netSNMPFile, arg, true)
+ default:
+ return syserr.ErrEndpointOperation.ToError()
+ }
+
+ if rawLine == "" {
+ return fmt.Errorf("Failed to get raw line")
+ }
+
+ parts := strings.SplitN(rawLine, ":", 2)
+ if len(parts) != 2 {
+ return fmt.Errorf("Failed to get prefix from: %q", rawLine)
+ }
+
+ sliceStat = toSlice(stat)
+ fields := strings.Fields(strings.TrimSpace(parts[1]))
+ if len(fields) != len(sliceStat) {
+ return fmt.Errorf("Failed to parse fields: %q", rawLine)
+ }
+ if _, ok := stat.(*inet.StatSNMPTCP); ok {
+ snmpTCP = true
+ }
+ for i := 0; i < len(sliceStat); i++ {
+ var err error
+ if snmpTCP && i == 3 {
+ var tmp int64
+ // The MaxConn field is signed per RFC 2012.
+ tmp, err = strconv.ParseInt(fields[i], 10, 64)
+ sliceStat[i] = uint64(tmp) // Stored as uint64; convert back to int64 when read.
+ } else {
+ sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64)
+ }
+ if err != nil {
+ return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err)
+ }
+ }
+
+ return nil
}
// RouteTable implements inet.Stack.RouteTable.
@@ -338,3 +443,12 @@ func (s *Stack) RouteTable() []inet.Route {
// Resume implements inet.Stack.Resume.
func (s *Stack) Resume() {}
+
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil }
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
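getLine and toSlice above lean on the fixed layout of these proc files: each stanza is a header line of column names followed by a line of values with the same prefix, which is why withHeader exists. A standalone sketch of the same parse, assuming a Linux host with /proc/net/snmp:

package main

import (
	"fmt"
	"io/ioutil"
	"strconv"
	"strings"
)

func main() {
	data, err := ioutil.ReadFile("/proc/net/snmp")
	if err != nil {
		panic(err)
	}
	// Collect the "Tcp:" header line and the "Tcp:" values line.
	var lines []string
	for _, l := range strings.Split(string(data), "\n") {
		if strings.HasPrefix(strings.TrimSpace(l), "Tcp:") {
			lines = append(lines, strings.TrimSpace(l))
		}
	}
	if len(lines) < 2 {
		panic("no Tcp stanza")
	}
	names := strings.Fields(strings.SplitN(lines[0], ":", 2)[1])
	values := strings.Fields(strings.SplitN(lines[1], ":", 2)[1])
	for i, name := range names {
		// MaxConn is signed per RFC 2012; every other field is unsigned.
		if name == "MaxConn" {
			v, _ := strconv.ParseInt(values[i], 10, 64)
			fmt.Printf("%s = %d (stored as %d)\n", name, v, uint64(v))
			continue
		}
		v, _ := strconv.ParseUint(values[i], 10, 64)
		fmt.Printf("%s = %d\n", name, v)
	}
}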
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index f95803f91..79589e3c8 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -20,6 +20,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
"//pkg/sentry/safemem",
"//pkg/sentry/socket",
diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go
index 689cad997..be005df24 100644
--- a/pkg/sentry/socket/netlink/provider.go
+++ b/pkg/sentry/socket/netlink/provider.go
@@ -30,6 +30,13 @@ type Protocol interface {
// Protocol returns the Linux netlink protocol value.
Protocol() int
+ // CanSend returns true if this protocol may ever send messages.
+ //
+ // TODO(gvisor.dev/issue/1119): This is a workaround to allow
+ // advertising support for otherwise unimplemented features on sockets
+ // that will never send messages, thus making those features no-ops.
+ CanSend() bool
+
// ProcessMessage processes a single message from userspace.
//
// If err == nil, any messages added to ms will be sent back to the
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index cc70ac237..6b4a0ecf4 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -61,6 +61,11 @@ func (p *Protocol) Protocol() int {
return linux.NETLINK_ROUTE
}
+// CanSend implements netlink.Protocol.CanSend.
+func (p *Protocol) CanSend() bool {
+ return true
+}
+
// dumpLinks handles RTM_GETLINK + NLM_F_DUMP requests.
func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
// NLM_F_DUMP + RTM_GETLINK messages are supposed to include an
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index b2732ca29..4a1b87a9a 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
@@ -53,6 +54,8 @@ const (
maxSendBufferSize = 4 << 20 // 4MB
)
+var errNoFilter = syserr.New("no filter attached", linux.ENOENT)
+
// netlinkSocketDevice is the netlink socket virtual device.
var netlinkSocketDevice = device.NewAnonDevice()
@@ -61,7 +64,7 @@ var netlinkSocketDevice = device.NewAnonDevice()
// This implementation only supports userspace sending and receiving messages
// to/from the kernel.
//
-// Socket implements socket.Socket.
+// Socket implements socket.Socket and transport.Credentialer.
//
// +stateify savable
type Socket struct {
@@ -104,9 +107,19 @@ type Socket struct {
// sendBufferSize is the send buffer "size". We don't actually have a
// fixed buffer but only consume this many bytes.
sendBufferSize uint32
+
+ // passcred indicates if this socket wants SCM credentials.
+ passcred bool
+
+ // filter indicates that this socket has a BPF filter "installed".
+ //
+ // TODO(gvisor.dev/issue/1119): We don't actually support filtering,
+ // this is just bookkeeping for tracking add/remove.
+ filter bool
}
var _ socket.Socket = (*Socket)(nil)
+var _ transport.Credentialer = (*Socket)(nil)
// NewSocket creates a new Socket.
func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) {
@@ -172,6 +185,22 @@ func (s *Socket) EventUnregister(e *waiter.Entry) {
s.ep.EventUnregister(e)
}
+// Passcred implements transport.Credentialer.Passcred.
+func (s *Socket) Passcred() bool {
+ s.mu.Lock()
+ passcred := s.passcred
+ s.mu.Unlock()
+ return passcred
+}
+
+// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
+func (s *Socket) ConnectedPasscred() bool {
+ // This socket is connected to the kernel, which doesn't need creds.
+ //
+ // This is arbitrary, as ConnectedPasscred on this type has no callers.
+ return false
+}
+
// Ioctl implements fs.FileOperations.Ioctl.
func (*Socket) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) {
// TODO(b/68878065): no ioctls supported.
@@ -309,9 +338,20 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.
// We don't have a limit on receiving size.
return int32(math.MaxInt32), nil
+ case linux.SO_PASSCRED:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ var passcred int32
+ if s.Passcred() {
+ passcred = 1
+ }
+ return passcred, nil
+
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
}
+
case linux.SOL_NETLINK:
switch name {
case linux.NETLINK_BROADCAST_ERROR,
@@ -348,6 +388,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy
s.sendBufferSize = size
s.mu.Unlock()
return nil
+
case linux.SO_RCVBUF:
if len(opt) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -355,6 +396,52 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy
// We don't have a limit on receiving size. So just accept anything as
// valid for compatibility.
return nil
+
+ case linux.SO_PASSCRED:
+ if len(opt) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ passcred := usermem.ByteOrder.Uint32(opt)
+
+ s.mu.Lock()
+ s.passcred = passcred != 0
+ s.mu.Unlock()
+ return nil
+
+ case linux.SO_ATTACH_FILTER:
+ // TODO(gvisor.dev/issue/1119): We don't actually
+ // support filtering. If this socket can't ever send
+ // messages, then there is nothing to filter and we can
+ // advertise support. Otherwise, be conservative and
+ // return an error.
+ if s.protocol.CanSend() {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ return syserr.ErrProtocolNotAvailable
+ }
+
+ s.mu.Lock()
+ s.filter = true
+ s.mu.Unlock()
+ return nil
+
+ case linux.SO_DETACH_FILTER:
+ // TODO(gvisor.dev/issue/1119): See above.
+ if s.protocol.CanSend() {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ return syserr.ErrProtocolNotAvailable
+ }
+
+ s.mu.Lock()
+ filter := s.filter
+ s.filter = false
+ s.mu.Unlock()
+
+ if !filter {
+ return errNoFilter
+ }
+
+ return nil
+
default:
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
@@ -483,6 +570,26 @@ func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _
})
}
+// kernelSCM implements control.SCMCredentials with credentials that represent
+// the kernel itself rather than a Task.
+//
+// +stateify savable
+type kernelSCM struct{}
+
+// Equals implements transport.CredentialsControlMessage.Equals.
+func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool {
+ _, ok := oc.(kernelSCM)
+ return ok
+}
+
+// Credentials implements control.SCMCredentials.Credentials.
+func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
+ return 0, auth.RootUID, auth.RootGID
+}
+
+// kernelCreds is the concrete version of kernelSCM used in all creds.
+var kernelCreds = &kernelSCM{}
+
// sendResponse sends the response messages in ms back to userspace.
func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
// Linux combines multiple netlink messages into a single datagram.
@@ -491,10 +598,15 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error
bufs = append(bufs, m.Finalize())
}
+ // All messages are from the kernel.
+ cms := transport.ControlMessages{
+ Credentials: kernelCreds,
+ }
+
if len(bufs) > 0 {
// RecvMsg never receives the address, so we don't need to send
// one.
- _, notify, err := s.connection.Send(bufs, transport.ControlMessages{}, tcpip.FullAddress{})
+ _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{})
// If the buffer is full, we simply drop messages, just like
// Linux.
if err != nil && err != syserr.ErrWouldBlock {
@@ -521,7 +633,7 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error
// Add the dump_done_errno payload.
m.Put(int64(0))
- _, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{})
+ _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{})
if err != nil && err != syserr.ErrWouldBlock {
return err
}
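On the receiving side, a userspace peer that enables SO_PASSCRED on a netlink socket sees the kernelCreds above as pid 0, uid 0, gid 0. A sketch of that round trip, assuming NETLINK_ROUTE, a little-endian host, and a minimal RTM_GETLINK dump request:

package main

import (
	"encoding/binary"
	"fmt"
	"syscall"
)

func main() {
	fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, syscall.NETLINK_ROUTE)
	if err != nil {
		panic(err)
	}
	defer syscall.Close(fd)

	// Ask for SCM_CREDENTIALS on received messages.
	if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1); err != nil {
		panic(err)
	}

	// Minimal RTM_GETLINK dump request: 16-byte nlmsghdr + 16-byte ifinfomsg.
	req := make([]byte, 32)
	binary.LittleEndian.PutUint32(req[0:4], 32)
	binary.LittleEndian.PutUint16(req[4:6], uint16(syscall.RTM_GETLINK))
	binary.LittleEndian.PutUint16(req[6:8], uint16(syscall.NLM_F_REQUEST|syscall.NLM_F_DUMP))
	binary.LittleEndian.PutUint32(req[8:12], 1) // sequence number
	if err := syscall.Sendto(fd, req, 0, &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK}); err != nil {
		panic(err)
	}

	buf := make([]byte, 4096)
	oob := make([]byte, 64)
	_, oobn, _, _, err := syscall.Recvmsg(fd, buf, oob, 0)
	if err != nil {
		panic(err)
	}
	msgs, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil || len(msgs) == 0 {
		panic("no control message")
	}
	cred, err := syscall.ParseUnixCredentials(&msgs[0])
	if err != nil {
		panic(err)
	}
	fmt.Printf("sender pid=%d uid=%d gid=%d\n", cred.Pid, cred.Uid, cred.Gid) // kernel: 0 0 0
}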
diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD
new file mode 100644
index 000000000..0777f3baf
--- /dev/null
+++ b/pkg/sentry/socket/netlink/uevent/BUILD
@@ -0,0 +1,17 @@
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "uevent",
+ srcs = ["protocol.go"],
+ importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/uevent",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/context",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/socket/netlink",
+ "//pkg/syserr",
+ ],
+)
diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go
new file mode 100644
index 000000000..b5d7808d7
--- /dev/null
+++ b/pkg/sentry/socket/netlink/uevent/protocol.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package uevent provides a NETLINK_KOBJECT_UEVENT socket protocol.
+//
+// NETLINK_KOBJECT_UEVENT sockets send udev-style device events. gVisor does
+// not support any device events, so these sockets never send any messages.
+package uevent
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+// Protocol implements netlink.Protocol.
+//
+// +stateify savable
+type Protocol struct{}
+
+var _ netlink.Protocol = (*Protocol)(nil)
+
+// NewProtocol creates a NETLINK_KOBJECT_UEVENT netlink.Protocol.
+func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) {
+ return &Protocol{}, nil
+}
+
+// Protocol implements netlink.Protocol.Protocol.
+func (p *Protocol) Protocol() int {
+ return linux.NETLINK_KOBJECT_UEVENT
+}
+
+// CanSend implements netlink.Protocol.CanSend.
+func (p *Protocol) CanSend() bool {
+ return false
+}
+
+// ProcessMessage implements netlink.Protocol.ProcessMessage.
+func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+ // Silently ignore all messages.
+ return nil
+}
+
+// init registers the NETLINK_KOBJECT_UEVENT provider.
+func init() {
+ netlink.RegisterProvider(linux.NETLINK_KOBJECT_UEVENT, NewProtocol)
+}
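A quick way to see the intended behavior: open and bind a NETLINK_KOBJECT_UEVENT socket and wait. On a Linux host, plugging in a device produces messages; under gVisor the socket is valid but permanently silent. A sketch:

package main

import (
	"fmt"
	"syscall"
)

func main() {
	fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, syscall.NETLINK_KOBJECT_UEVENT)
	if err != nil {
		panic(err)
	}
	defer syscall.Close(fd)
	// Multicast group 1 is where the kernel broadcasts udev-style uevents.
	sa := &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK, Groups: 1}
	if err := syscall.Bind(fd, sa); err != nil {
		panic(err)
	}
	fmt.Println("bound; on Linux this receives hotplug events, under gVisor it stays silent")
}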
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 0ae573b45..140851c17 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -53,6 +53,7 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -148,6 +149,10 @@ var Metrics = tcpip.Stats{
TCP: tcpip.TCPStats{
ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."),
PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."),
+ CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."),
+ EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state."),
+ EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "Number of times established TCP connections made a transition to the CLOSED state."),
+ EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of a keep-alive timeout."),
ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."),
ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."),
ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."),
@@ -278,7 +283,7 @@ type SocketOperations struct {
// New creates a new endpoint socket.
func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil {
+ if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
}
@@ -296,6 +301,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue
var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{}))
var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{}))
+var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{}))
// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
// netstack representation taking any addresses into account.
@@ -307,12 +313,12 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
}
// AddressAndFamily reads an sockaddr struct from the given address and
-// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and
-// AF_INET6 addresses.
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
+// AF_INET6, and AF_PACKET addresses.
//
// strict indicates whether addresses with the AF_UNSPEC family are accepted or not.
//
-// AddressAndFamily returns an address, its family.
+// AddressAndFamily returns an address and its family.
func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
@@ -320,7 +326,7 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress,
}
family := usermem.ByteOrder.Uint16(addr)
- if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) {
+ if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) {
return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported
}
@@ -371,6 +377,22 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress,
}
return out, family, nil
+ case linux.AF_PACKET:
+ var a linux.SockAddrLink
+ if len(addr) < sockAddrLinkSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+
+ // TODO(b/129292371): Return protocol too.
+ return tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }, family, nil
+
case linux.AF_UNSPEC:
return tcpip.FullAddress{}, family, nil
@@ -1035,8 +1057,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.DelayOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.DelayOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1105,6 +1127,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return int32(time.Duration(v) / time.Second), nil
+ case linux.TCP_USER_TIMEOUT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPUserTimeoutOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Millisecond), nil
+
case linux.TCP_INFO:
var v tcpip.TCPInfoOption
if err := ep.GetSockOpt(&v); err != nil {
@@ -1153,6 +1187,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
copy(b, v)
return b, nil
+ case linux.TCP_LINGER2:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPLingerTimeoutOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1477,11 +1523,11 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- var o tcpip.DelayOption
+ var o int
if v == 0 {
o = 1
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(o))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o))
case linux.TCP_CORK:
if len(optVal) < sizeOfInt32 {
@@ -1529,6 +1575,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
+ case linux.TCP_USER_TIMEOUT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < 0 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v))))
+
case linux.TCP_CONGESTION:
v := tcpip.CongestionControlOption(optVal)
if err := ep.SetSockOpt(v); err != nil {
@@ -1536,6 +1593,14 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
return nil
+ case linux.TCP_LINGER2:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v))))
+
case linux.TCP_REPAIR_OPTIONS:
t.Kernel().EmitUnimplementedEvent(t)
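From an application's point of view, both new options take plain ints with the unit conversions shown above: milliseconds for TCP_USER_TIMEOUT and seconds for TCP_LINGER2. A sketch using golang.org/x/sys/unix; the dial target is a placeholder:

package main

import (
	"fmt"
	"net"

	"golang.org/x/sys/unix"
)

func main() {
	conn, err := net.Dial("tcp", "example.com:80")
	if err != nil {
		panic(err)
	}
	defer conn.Close()
	raw, err := conn.(*net.TCPConn).SyscallConn()
	if err != nil {
		panic(err)
	}
	if err := raw.Control(func(fd uintptr) {
		// Give unacknowledged data 30 seconds (30000 ms) before the
		// connection is reset.
		if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, 30000); err != nil {
			panic(err)
		}
		// Linger in FIN-WAIT-2 for at most 10 seconds.
		if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_TCP, unix.TCP_LINGER2, 10); err != nil {
			panic(err)
		}
	}); err != nil {
		panic(err)
	}
	fmt.Println("options set")
}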
@@ -1951,12 +2016,14 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
return &out, uint32(2 + l)
}
return &out, uint32(3 + l)
+
case linux.AF_INET:
var out linux.SockAddrInet
copy(out.Addr[:], addr.Addr)
out.Family = linux.AF_INET
out.Port = htons(addr.Port)
- return &out, uint32(binary.Size(out))
+ return &out, uint32(sockAddrInetSize)
+
case linux.AF_INET6:
var out linux.SockAddrInet6
if len(addr.Addr) == 4 {
@@ -1972,7 +2039,17 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
if isLinkLocal(addr.Addr) {
out.Scope_id = uint32(addr.NIC)
}
- return &out, uint32(binary.Size(out))
+ return &out, uint32(sockAddrInet6Size)
+
+ case linux.AF_PACKET:
+ // TODO(b/129292371): Return protocol too.
+ var out linux.SockAddrLink
+ out.Family = linux.AF_PACKET
+ out.InterfaceIndex = int32(addr.NIC)
+ out.HardwareAddrLen = header.EthernetAddressSize
+ copy(out.HardwareAddr[:], addr.Addr)
+ return &out, uint32(sockAddrLinkSize)
+
default:
return nil, 0
}
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index 357a664cc..2d2c1ba2a 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -62,6 +62,10 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
}
case linux.SOCK_RAW:
+ // TODO(b/142504697): "In order to create a raw socket, a
+ // process must have the CAP_NET_RAW capability in the user
+ // namespace that governs its network namespace." - raw(7)
+
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
@@ -85,7 +89,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
return 0, true, syserr.ErrProtocolNotSupported
}
-// Socket creates a new socket object for the AF_INET or AF_INET6 family.
+// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET
+// family.
func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Fail right away if we don't have a stack.
stack := t.NetworkContext()
@@ -99,6 +104,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
return nil, nil
}
+ // Packet sockets are handled separately, since they are neither INET
+ // nor INET6 specific.
+ if p.family == linux.AF_PACKET {
+ return packetSocket(t, eps, stype, protocol)
+ }
+
// Figure out the transport protocol.
transProto, associated, err := getTransportProtocol(t, stype, protocol)
if err != nil {
@@ -121,12 +132,47 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
return New(t, p.family, stype, int(transProto), wq, ep)
}
+func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // TODO(b/142504697): "In order to create a packet socket, a process
+ // must have the CAP_NET_RAW capability in the user namespace that
+ // governs its network namespace." - packet(7)
+
+ // Packet sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(t)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return nil, syserr.ErrNotPermitted
+ }
+
+ // "cooked" packets don't contain link layer information.
+ var cooked bool
+ switch stype {
+ case linux.SOCK_DGRAM:
+ cooked = true
+ case linux.SOCK_RAW:
+ cooked = false
+ default:
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // protocol is passed in network byte order, but netstack wants it in
+ // host order.
+ netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+
+ wq := &waiter.Queue{}
+ ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return New(t, linux.AF_PACKET, stype, protocol, wq, ep)
+}
+
// Pair just returns nil sockets (not supported).
func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
return nil, nil, nil
}
-// init registers socket providers for AF_INET and AF_INET6.
+// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET.
func init() {
// Providers backed by netstack.
p := []provider{
@@ -138,6 +184,9 @@ func init() {
family: linux.AF_INET6,
netProto: ipv6.ProtocolNumber,
},
+ {
+ family: linux.AF_PACKET,
+ },
}
for i := range p {
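Creating one of these sockets from userspace requires CAP_NET_RAW, and the protocol argument travels in network byte order, matching the ntohs() conversion in packetSocket above. A sketch, assuming interface index 1 (usually loopback):

package main

import (
	"fmt"
	"syscall"
)

// htons converts to network byte order, since the AF_PACKET protocol
// argument is specified big-endian.
func htons(v uint16) uint16 { return v<<8 | v>>8 }

func main() {
	// "Cooked" (SOCK_DGRAM) packet sockets strip the link-layer header;
	// SOCK_RAW would include it.
	fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_DGRAM, int(htons(syscall.ETH_P_ALL)))
	if err != nil {
		panic(err) // typically EPERM without CAP_NET_RAW
	}
	defer syscall.Close(fd)
	sa := &syscall.SockaddrLinklayer{Protocol: htons(syscall.ETH_P_ALL), Ifindex: 1}
	if err := syscall.Bind(fd, sa); err != nil {
		panic(err)
	}
	fmt.Println("packet socket bound")
}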
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index fda0156e5..a0db2d4fd 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -144,7 +144,98 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
// Statistics implements inet.Stack.Statistics.
func (s *Stack) Statistics(stat interface{}, arg string) error {
- return syserr.ErrEndpointOperation.ToError()
+ switch stats := stat.(type) {
+ case *inet.StatSNMPIP:
+ ip := Metrics.IP
+ *stats = inet.StatSNMPIP{
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL.
+ ip.PacketsReceived.Value(), // InReceives.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors.
+ ip.InvalidAddressesReceived.Value(), // InAddrErrors.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards.
+ ip.PacketsDelivered.Value(), // InDelivers.
+ ip.PacketsSent.Value(), // OutRequests.
+ ip.OutgoingPacketErrors.Value(), // OutDiscards.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates.
+ }
+ case *inet.StatSNMPICMP:
+ in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
+ out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ *stats = inet.StatSNMPICMP{
+ 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs.
+ Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
+ 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors.
+ in.DstUnreachable.Value(), // InDestUnreachs.
+ in.TimeExceeded.Value(), // InTimeExcds.
+ in.ParamProblem.Value(), // InParmProbs.
+ in.SrcQuench.Value(), // InSrcQuenchs.
+ in.Redirect.Value(), // InRedirects.
+ in.Echo.Value(), // InEchos.
+ in.EchoReply.Value(), // InEchoReps.
+ in.Timestamp.Value(), // InTimestamps.
+ in.TimestampReply.Value(), // InTimestampReps.
+ in.InfoRequest.Value(), // InAddrMasks.
+ in.InfoReply.Value(), // InAddrMaskReps.
+ 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs.
+ Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
+ out.DstUnreachable.Value(), // OutDestUnreachs.
+ out.TimeExceeded.Value(), // OutTimeExcds.
+ out.ParamProblem.Value(), // OutParmProbs.
+ out.SrcQuench.Value(), // OutSrcQuenchs.
+ out.Redirect.Value(), // OutRedirects.
+ out.Echo.Value(), // OutEchos.
+ out.EchoReply.Value(), // OutEchoReps.
+ out.Timestamp.Value(), // OutTimestamps.
+ out.TimestampReply.Value(), // OutTimestampReps.
+ out.InfoRequest.Value(), // OutAddrMasks.
+ out.InfoReply.Value(), // OutAddrMaskReps.
+ }
+ case *inet.StatSNMPTCP:
+ tcp := Metrics.TCP
+ // RFC 2012 (updates 1213): SNMPv2-MIB-TCP.
+ *stats = inet.StatSNMPTCP{
+ 1, // RtoAlgorithm.
+ 200, // RtoMin.
+ 120000, // RtoMax.
+ (1<<64 - 1), // MaxConn.
+ tcp.ActiveConnectionOpenings.Value(), // ActiveOpens.
+ tcp.PassiveConnectionOpenings.Value(), // PassiveOpens.
+ tcp.FailedConnectionAttempts.Value(), // AttemptFails.
+ tcp.EstablishedResets.Value(), // EstabResets.
+ tcp.CurrentEstablished.Value(), // CurrEstab.
+ tcp.ValidSegmentsReceived.Value(), // InSegs.
+ tcp.SegmentsSent.Value(), // OutSegs.
+ tcp.Retransmits.Value(), // RetransSegs.
+ tcp.InvalidSegmentsReceived.Value(), // InErrs.
+ tcp.ResetsSent.Value(), // OutRsts.
+ tcp.ChecksumErrors.Value(), // InCsumErrors.
+ }
+ case *inet.StatSNMPUDP:
+ udp := Metrics.UDP
+ *stats = inet.StatSNMPUDP{
+ udp.PacketsReceived.Value(), // InDatagrams.
+ udp.UnknownPortErrors.Value(), // NoPorts.
+ 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors.
+ udp.PacketsSent.Value(), // OutDatagrams.
+ udp.ReceiveBufferErrors.Value(), // RcvbufErrors.
+ 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors.
+ 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors.
+ 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti.
+ }
+ default:
+ return syserr.ErrEndpointOperation.ToError()
+ }
+ return nil
}
// RouteTable implements inet.Stack.RouteTable.
@@ -200,3 +291,18 @@ func (s *Stack) FillDefaultIPTables() {
func (s *Stack) Resume() {
s.Stack.Resume()
}
+
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint {
+ return s.Stack.RegisteredEndpoints()
+}
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint {
+ return s.Stack.CleanupEndpoints()
+}
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) {
+ s.Stack.RestoreCleanupEndpoints(es)
+}
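The MaxConn entry is worth a second look: RFC 2012 defines it as -1 when there is no fixed maximum, and the uint64 stat slot stores that value's two's-complement bit pattern, which is also what the hostinet parser above reproduces via ParseInt. A worked round trip:

package main

import (
	"fmt"
	"strconv"
)

func main() {
	// netstack reports MaxConn as all-ones in a uint64 slot...
	stored := uint64(1<<64 - 1)
	fmt.Println(int64(stored)) // -1, i.e. no fixed maximum per RFC 2012
	// ...while hostinet parses the signed text from /proc/net/snmp and
	// stores the same bit pattern.
	v, _ := strconv.ParseInt("-1", 10, 64)
	fmt.Println(uint64(v) == stored) // true
}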
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
index 3a6baa308..4668b87d1 100644
--- a/pkg/sentry/socket/rpcinet/BUILD
+++ b/pkg/sentry/socket/rpcinet/BUILD
@@ -37,6 +37,7 @@ go_library(
"//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
"//pkg/unet",
"//pkg/waiter",
],
diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go
index 5dcb6b455..f7878a760 100644
--- a/pkg/sentry/socket/rpcinet/stack.go
+++ b/pkg/sentry/socket/rpcinet/stack.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
"gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier"
"gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -165,3 +166,12 @@ func (s *Stack) RouteTable() []inet.Route {
// Resume implements inet.Stack.Resume.
func (s *Stack) Resume() {}
+
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil }
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc.proto b/pkg/sentry/socket/rpcinet/syscall_rpc.proto
index 9586f5923..b677e9eb3 100644
--- a/pkg/sentry/socket/rpcinet/syscall_rpc.proto
+++ b/pkg/sentry/socket/rpcinet/syscall_rpc.proto
@@ -3,7 +3,6 @@ syntax = "proto3";
// package syscall_rpc is a set of networking related system calls that can be
// forwarded to a socket gofer.
//
-// TODO(b/77963526): Document individual RPCs.
package syscall_rpc;
message SendmsgRequest {
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 8c250c325..2389a9cdb 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -43,6 +43,11 @@ type ControlMessages struct {
IP tcpip.ControlMessages
}
+// Release releases Unix domain socket credentials and rights.
+func (c *ControlMessages) Release() {
+ c.Unix.Release()
+}
+
// Socket is the interface containing socket syscalls used by the syscall layer
// to redirect them to the appropriate implementation.
type Socket interface {
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 1aaae8487..885758054 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -118,6 +118,9 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
func extractPath(sockaddr []byte) (string, *syserr.Error) {
addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */)
if err != nil {
+ if err == syserr.ErrAddressFamilyNotSupported {
+ err = syserr.ErrInvalidArgument
+ }
return "", err
}
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 72ebf766d..d46421199 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -14,6 +14,7 @@ go_library(
"open.go",
"poll.go",
"ptrace.go",
+ "select.go",
"signal.go",
"socket.go",
"strace.go",
diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go
index 5d57b75af..e603f858f 100644
--- a/pkg/sentry/strace/linux64.go
+++ b/pkg/sentry/strace/linux64.go
@@ -40,7 +40,7 @@ var linuxAMD64 = SyscallMap{
20: makeSyscallInfo("writev", FD, WriteIOVec, Hex),
21: makeSyscallInfo("access", Path, Oct),
22: makeSyscallInfo("pipe", PipeFDs),
- 23: makeSyscallInfo("select", Hex, Hex, Hex, Hex, Timeval),
+ 23: makeSyscallInfo("select", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timeval),
24: makeSyscallInfo("sched_yield"),
25: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex),
26: makeSyscallInfo("msync", Hex, Hex, Hex),
@@ -287,7 +287,7 @@ var linuxAMD64 = SyscallMap{
267: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex),
268: makeSyscallInfo("fchmodat", FD, Path, Mode),
269: makeSyscallInfo("faccessat", FD, Path, Oct, Hex),
- 270: makeSyscallInfo("pselect6", Hex, Hex, Hex, Hex, Hex, Hex),
+ 270: makeSyscallInfo("pselect6", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timespec, SigSet),
271: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex),
272: makeSyscallInfo("unshare", CloneFlags),
273: makeSyscallInfo("set_robust_list", Hex, Hex),
@@ -335,5 +335,33 @@ var linuxAMD64 = SyscallMap{
315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
317: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+ 318: makeSyscallInfo("getrandom", Hex, Hex, Hex),
+ 319: makeSyscallInfo("memfd_create", Path, Hex), // Not quite a path, but close.
+ 320: makeSyscallInfo("kexec_file_load", FD, FD, Hex, Hex, Hex),
+ 321: makeSyscallInfo("bpf", Hex, Hex, Hex),
+ 322: makeSyscallInfo("execveat", FD, Path, ExecveStringVector, ExecveStringVector, Hex),
+ 323: makeSyscallInfo("userfaultfd", Hex),
+ 324: makeSyscallInfo("membarrier", Hex, Hex),
+ 325: makeSyscallInfo("mlock2", Hex, Hex, Hex),
+ 326: makeSyscallInfo("copy_file_range", FD, Hex, FD, Hex, Hex, Hex),
+ 327: makeSyscallInfo("preadv2", FD, ReadIOVec, Hex, Hex, Hex),
+ 328: makeSyscallInfo("pwritev2", FD, WriteIOVec, Hex, Hex, Hex),
+ 329: makeSyscallInfo("pkey_mprotect", Hex, Hex, Hex, Hex),
+ 330: makeSyscallInfo("pkey_alloc", Hex, Hex),
+ 331: makeSyscallInfo("pkey_free", Hex),
332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex),
+ 333: makeSyscallInfo("io_pgetevents", Hex, Hex, Hex, Hex, Timespec, SigSet),
+ 334: makeSyscallInfo("rseq", Hex, Hex, Hex, Hex),
+ 424: makeSyscallInfo("pidfd_send_signal", FD, Signal, Hex, Hex),
+ 425: makeSyscallInfo("io_uring_setup", Hex, Hex),
+ 426: makeSyscallInfo("io_uring_enter", FD, Hex, Hex, Hex, SigSet, Hex),
+ 427: makeSyscallInfo("io_uring_register", FD, Hex, Hex, Hex),
+ 428: makeSyscallInfo("open_tree", FD, Path, Hex),
+ 429: makeSyscallInfo("move_mount", FD, Path, FD, Path, Hex),
+ 430: makeSyscallInfo("fsopen", Path, Hex), // Not quite a path, but close.
+ 431: makeSyscallInfo("fsconfig", FD, Hex, Hex, Hex, Hex),
+ 432: makeSyscallInfo("fsmount", FD, Hex, Hex),
+ 433: makeSyscallInfo("fspick", FD, Path, Hex),
+ 434: makeSyscallInfo("pidfd_open", Hex, Hex),
+ 435: makeSyscallInfo("clone3", Hex, Hex),
}
diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go
new file mode 100644
index 000000000..c77d418e6
--- /dev/null
+++ b/pkg/sentry/strace/select.go
@@ -0,0 +1,56 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+func fdsFromSet(t *kernel.Task, set []byte) []int {
+ var fds []int
+ // Append n if the n-th bit is 1.
+ for i, v := range set {
+ for j := 0; j < 8; j++ {
+ if (v>>j)&1 == 1 {
+ fds = append(fds, i*8+j)
+ }
+ }
+ }
+ return fds
+}
+
+func fdSet(t *kernel.Task, nfds int, addr usermem.Addr) string {
+ if nfds < 0 {
+ return fmt.Sprintf("%#x (negative nfds)", addr)
+ }
+ if addr == 0 {
+ return "null"
+ }
+
+ // Calculate the size of the fd set (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := nfds % 8
+
+ set, err := linux.CopyInFDSet(t, addr, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding fdset: %s)", addr, err)
+ }
+
+ return fmt.Sprintf("%#x %v", addr, fdsFromSet(t, set))
+}
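The bit-to-fd mapping is the standard fd_set layout: byte i, bit j encodes fd i*8+j (on little-endian hosts). A worked example of the same decoding logic:

package main

import "fmt"

// fds mirrors fdsFromSet above: bit n set means fd n is in the set.
func fds(set []byte) []int {
	var out []int
	for i, v := range set {
		for j := 0; j < 8; j++ {
			if (v>>j)&1 == 1 {
				out = append(out, i*8+j)
			}
		}
	}
	return out
}

func main() {
	// 0x12 = 0b00010010: bits 1 and 4 of byte 0, plus bit 0 of byte 1.
	fmt.Println(fds([]byte{0x12, 0x01})) // [1 4 8]
}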
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index 94334f6d2..51f2efb39 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -208,6 +208,15 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
i += linux.SizeOfControlMessageHeader
width := t.Arch().Width()
length := int(h.Length) - linux.SizeOfControlMessageHeader
+ if length < 0 {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content too short}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
if skipData {
strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length))
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 311389547..629c1f308 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -439,6 +439,8 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo
output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer()))
case PollFDs:
output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false))
+ case SelectFDSet:
+ output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer()))
case Oct:
output = append(output, "0o"+strconv.FormatUint(args[arg].Uint64(), 8))
case Hex:
diff --git a/pkg/sentry/strace/strace.proto b/pkg/sentry/strace/strace.proto
index 4b2f73a5f..906c52c51 100644
--- a/pkg/sentry/strace/strace.proto
+++ b/pkg/sentry/strace/strace.proto
@@ -32,8 +32,7 @@ message Strace {
}
}
-message StraceEnter {
-}
+message StraceEnter {}
message StraceExit {
// Return value formatted as string.
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
index 3c389d375..e5d486c4e 100644
--- a/pkg/sentry/strace/syscalls.go
+++ b/pkg/sentry/strace/syscalls.go
@@ -206,6 +206,10 @@ const (
// PollFDs is an array of struct pollfd. The number of entries in the
// array is in the next argument.
PollFDs
+
+ // SelectFDSet is an fd_set argument in select(2)/pselect(2). The number of
+ // fds represented must be the first argument.
+ SelectFDSet
)
// defaultFormat is the syscall argument format to use if the actual format is
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index fb2c1777f..6766ba587 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -49,6 +49,7 @@ go_library(
"sys_tls.go",
"sys_utsname.go",
"sys_write.go",
+ "sys_xattr.go",
"timespec.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls/linux",
@@ -79,6 +80,7 @@ go_library(
"//pkg/sentry/kernel/signalfd",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
+ "//pkg/sentry/loader",
"//pkg/sentry/memmap",
"//pkg/sentry/mm",
"//pkg/sentry/safemem",
diff --git a/pkg/sentry/syscalls/linux/flags.go b/pkg/sentry/syscalls/linux/flags.go
index 444f2b004..07961dad9 100644
--- a/pkg/sentry/syscalls/linux/flags.go
+++ b/pkg/sentry/syscalls/linux/flags.go
@@ -50,5 +50,6 @@ func linuxToFlags(mask uint) fs.FileFlags {
Directory: mask&linux.O_DIRECTORY != 0,
Async: mask&linux.O_ASYNC != 0,
LargeFile: mask&linux.O_LARGEFILE != 0,
+ Truncate: mask&linux.O_TRUNC != 0,
}
}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index b317cb99d..68589a377 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -16,7 +16,12 @@
package linux
const (
- _LINUX_SYSNAME = "Linux"
- _LINUX_RELEASE = "4.4.0"
- _LINUX_VERSION = "#1 SMP Sun Jan 10 15:06:54 PST 2016"
+ // LinuxSysname is the OS name advertised by gVisor.
+ LinuxSysname = "Linux"
+
+ // LinuxRelease is the Linux release version number advertised by gVisor.
+ LinuxRelease = "4.4.0"
+
+ // LinuxVersion is the version info advertised by gVisor.
+ LinuxVersion = "#1 SMP Sun Jan 10 15:06:54 PST 2016"
)
diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go
index e215ac049..272ae9991 100644
--- a/pkg/sentry/syscalls/linux/linux64_amd64.go
+++ b/pkg/sentry/syscalls/linux/linux64_amd64.go
@@ -34,9 +34,9 @@ var AMD64 = &kernel.SyscallTable{
// guides the interface provided by this syscall table. The build
// version is that for a clean build with default kernel config, at 5
// minutes after v4.4 was tagged.
- Sysname: _LINUX_SYSNAME,
- Release: _LINUX_RELEASE,
- Version: _LINUX_VERSION,
+ Sysname: LinuxSysname,
+ Release: LinuxRelease,
+ Version: LinuxVersion,
},
AuditNumber: linux.AUDIT_ARCH_X86_64,
Table: map[uintptr]kernel.Syscall{
@@ -228,10 +228,10 @@ var AMD64 = &kernel.SyscallTable{
185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil),
186: syscalls.Supported("gettid", Gettid),
187: syscalls.Supported("readahead", Readahead),
- 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 188: syscalls.PartiallySupported("setxattr", Setxattr, "Only supported for tmpfs.", nil),
189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 191: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 191: syscalls.PartiallySupported("getxattr", Getxattr, "Only supported for tmpfs.", nil),
192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
@@ -260,7 +260,7 @@ var AMD64 = &kernel.SyscallTable{
217: syscalls.Supported("getdents64", Getdents64),
218: syscalls.Supported("set_tid_address", SetTidAddress),
219: syscalls.Supported("restart_syscall", RestartSyscall),
- 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920)
+ 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil),
222: syscalls.Supported("timer_create", TimerCreate),
223: syscalls.Supported("timer_settime", TimerSettime),
@@ -362,16 +362,36 @@ var AMD64 = &kernel.SyscallTable{
319: syscalls.Supported("memfd_create", MemfdCreate),
320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil),
321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
- 322: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836)
+ 322: syscalls.Supported("execveat", Execveat),
323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
- 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897)
+ 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267)
325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- // Syscalls after 325 are "backports" from versions of Linux after 4.4.
+ // Syscalls implemented after 325 are "backports" from versions
+ // of Linux after 4.4.
326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
327: syscalls.Supported("preadv2", Preadv2),
328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
+ 329: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil),
+ 330: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil),
+ 331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
332: syscalls.Supported("statx", Statx),
+ 333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
+ 334: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil),
+
+ // Linux skips ahead to syscall 424 to sync numbers between arches.
+ 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
+ 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil),
+ 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil),
+ 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil),
+ 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil),
+ 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil),
+ 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil),
+ 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil),
+ 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil),
+ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
+ 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
+ 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
},
Emulate: map[usermem.Addr]uintptr{
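With setxattr/getxattr now wired up for tmpfs, the usual userspace calls work there and continue to return ENOTSUP elsewhere. A sketch using golang.org/x/sys/unix, assuming /tmp is tmpfs and a user.* attribute:

package main

import (
	"fmt"
	"os"

	"golang.org/x/sys/unix"
)

func main() {
	f, err := os.Create("/tmp/xattr-demo")
	if err != nil {
		panic(err)
	}
	f.Close()
	if err := unix.Setxattr("/tmp/xattr-demo", "user.demo", []byte("hello"), 0); err != nil {
		panic(err) // ENOTSUP if the filesystem lacks xattr support
	}
	buf := make([]byte, 64)
	n, err := unix.Getxattr("/tmp/xattr-demo", "user.demo", buf)
	if err != nil {
		panic(err)
	}
	fmt.Printf("user.demo = %q\n", buf[:n])
}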
diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go
index f82dfac31..3b584eed9 100644
--- a/pkg/sentry/syscalls/linux/linux64_arm64.go
+++ b/pkg/sentry/syscalls/linux/linux64_arm64.go
@@ -30,9 +30,9 @@ var ARM64 = &kernel.SyscallTable{
OS: abi.Linux,
Arch: arch.ARM64,
Version: kernel.Version{
- Sysname: _LINUX_SYSNAME,
- Release: _LINUX_RELEASE,
- Version: _LINUX_VERSION,
+ Sysname: LinuxSysname,
+ Release: LinuxRelease,
+ Version: LinuxVersion,
},
AuditNumber: linux.AUDIT_ARCH_AARCH64,
Table: map[uintptr]kernel.Syscall{
@@ -41,10 +41,10 @@ var ARM64 = &kernel.SyscallTable{
2: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
3: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
4: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 5: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 5: syscalls.PartiallySupported("setxattr", Setxattr, "Only supported for tmpfs.", nil),
6: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
7: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 8: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 8: syscalls.PartiallySupported("getxattr", Getxattr, "Only supported for tmpfs.", nil),
9: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
10: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
11: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
@@ -224,7 +224,7 @@ var ARM64 = &kernel.SyscallTable{
189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
190: syscalls.Supported("semget", Semget),
191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
- 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920)
+ 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil),
@@ -295,14 +295,33 @@ var ARM64 = &kernel.SyscallTable{
278: syscalls.Supported("getrandom", GetRandom),
279: syscalls.Supported("memfd_create", MemfdCreate),
280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
- 281: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836)
+ 281: syscalls.Supported("execveat", Execveat),
282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
- 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897)
+ 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267)
284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
286: syscalls.Supported("preadv2", Preadv2),
287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
+ 288: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil),
+ 289: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil),
+ 290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
291: syscalls.Supported("statx", Statx),
+ 292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
+ 293: syscalls.ErrorWithEvent("rseq", syserror.ENOSYS, "", nil),
+
+ // Linux skips ahead to syscall 424 to sync numbers between arches.
+ 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
+ 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil),
+ 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil),
+ 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil),
+ 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil),
+ 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil),
+ 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil),
+ 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil),
+ 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil),
+ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
+ 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
+ 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
},
Emulate: map[usermem.Addr]uintptr{},
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index b9a8e3e21..9bc2445a5 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -169,10 +169,14 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint
if dirPath {
return syserror.ENOTDIR
}
- if flags&linux.O_TRUNC != 0 {
- if err := d.Inode.Truncate(t, d, 0); err != nil {
- return err
- }
+ }
+
+ // Truncate is called when O_TRUNC is specified for any kind of
+ // existing Dirent. Behavior is delegated to the entry's Truncate
+ // implementation.
+ if flags&linux.O_TRUNC != 0 {
+ if err := d.Inode.Truncate(t, d, 0); err != nil {
+ return err
}
}
@@ -396,7 +400,9 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l
return err
}
- // Should we truncate the file?
+ // Truncate is called when O_TRUNC is specified for any kind of
+ // existing Dirent. Behavior is delegated to the entry's Truncate
+ // implementation.
if flags&linux.O_TRUNC != 0 {
if err := found.Inode.Truncate(t, found, 0); err != nil {
return err
@@ -834,23 +840,40 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(newfd), nil, nil
}
-func fGetOwn(t *kernel.Task, file *fs.File) int32 {
+func fGetOwnEx(t *kernel.Task, file *fs.File) linux.FOwnerEx {
ma := file.Async(nil)
if ma == nil {
- return 0
+ return linux.FOwnerEx{}
}
a := ma.(*fasync.FileAsync)
ot, otg, opg := a.Owner()
switch {
case ot != nil:
- return int32(t.PIDNamespace().IDOfTask(ot))
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_TID,
+ PID: int32(t.PIDNamespace().IDOfTask(ot)),
+ }
case otg != nil:
- return int32(t.PIDNamespace().IDOfThreadGroup(otg))
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PID,
+ PID: int32(t.PIDNamespace().IDOfThreadGroup(otg)),
+ }
case opg != nil:
- return int32(-t.PIDNamespace().IDOfProcessGroup(opg))
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PGRP,
+ PID: int32(t.PIDNamespace().IDOfProcessGroup(opg)),
+ }
default:
- return 0
+ return linux.FOwnerEx{}
+ }
+}
+
+func fGetOwn(t *kernel.Task, file *fs.File) int32 {
+ owner := fGetOwnEx(t, file)
+ if owner.Type == linux.F_OWNER_PGRP {
+ return -owner.PID
}
+ return owner.PID
}
// fSetOwn sets the file's owner with the semantics of F_SETOWN in Linux.
@@ -895,11 +918,13 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
t.FDTable().SetFlags(fd, kernel.FDFlags{
CloseOnExec: flags&linux.FD_CLOEXEC != 0,
})
+ return 0, nil, nil
case linux.F_GETFL:
return uintptr(file.Flags().ToLinux()), nil, nil
case linux.F_SETFL:
flags := uint(args[2].Uint())
file.SetFlags(linuxToFlags(flags).Settable())
+ return 0, nil, nil
case linux.F_SETLK, linux.F_SETLKW:
// In Linux the file system can choose to provide lock operations for an inode.
// Normally pipe and socket types lack lock operations. We diverge and use a heavy
@@ -1002,6 +1027,44 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.F_SETOWN:
fSetOwn(t, file, args[2].Int())
return 0, nil, nil
+ case linux.F_GETOWN_EX:
+ addr := args[2].Pointer()
+ owner := fGetOwnEx(t, file)
+ _, err := t.CopyOut(addr, &owner)
+ return 0, nil, err
+ case linux.F_SETOWN_EX:
+ addr := args[2].Pointer()
+ var owner linux.FOwnerEx
+ n, err := t.CopyIn(addr, &owner)
+ if err != nil {
+ return 0, nil, err
+ }
+ a := file.Async(fasync.New).(*fasync.FileAsync)
+ switch owner.Type {
+ case linux.F_OWNER_TID:
+ task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID))
+ if task == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ a.SetOwnerTask(t, task)
+ return uintptr(n), nil, nil
+ case linux.F_OWNER_PID:
+ tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(owner.PID))
+ if tg == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ a.SetOwnerThreadGroup(t, tg)
+ return uintptr(n), nil, nil
+ case linux.F_OWNER_PGRP:
+ pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(owner.PID))
+ if pg == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ a.SetOwnerProcessGroup(t, pg)
+ return uintptr(n), nil, nil
+ default:
+ return 0, nil, syserror.EINVAL
+ }
case linux.F_GET_SEALS:
val, err := tmpfs.GetSeals(file.Dirent.Inode)
return uintptr(val), nil, err
@@ -1029,7 +1092,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Everything else is not yet supported.
return 0, nil, syserror.EINVAL
}
- return 0, nil, nil
}
const (
@@ -1483,6 +1545,8 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if fs.IsDir(d.Inode.StableAttr) {
return syserror.EISDIR
}
+ // In contrast to open(O_TRUNC), truncate(2) is only valid for file
+ // types.
if !fs.IsFile(d.Inode.StableAttr) {
return syserror.EINVAL
}
@@ -1521,7 +1585,8 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, syserror.EINVAL
}
- // Note that this is different from truncate(2) above, where a
+ // In contrast to open(O_TRUNC), ftruncate(2) is only valid for file
+ // types. Note that this is different from truncate(2) above, where a
// directory returns EISDIR.
if !fs.IsFile(file.Dirent.Inode.StableAttr) {
return 0, nil, syserror.EINVAL
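The fGetOwn shim above collapses the richer F_GETOWN_EX result into a single integer, reporting process groups as negative values. A minimal standalone sketch of that convention (simplified stand-in types; the owner-type constants match Linux's fcntl.h):

package main

import "fmt"

// Owner type constants, as in Linux's include/uapi/linux/fcntl.h.
const (
	F_OWNER_TID  = 0
	F_OWNER_PID  = 1
	F_OWNER_PGRP = 2
)

// fOwnerEx is a simplified stand-in for linux.FOwnerEx.
type fOwnerEx struct {
	Type int32
	PID  int32
}

// legacyOwner reproduces the F_GETOWN convention: process groups are
// returned as the negated group ID, everything else as-is.
func legacyOwner(owner fOwnerEx) int32 {
	if owner.Type == F_OWNER_PGRP {
		return -owner.PID
	}
	return owner.PID
}

func main() {
	fmt.Println(legacyOwner(fOwnerEx{Type: F_OWNER_PGRP, PID: 42})) // -42
	fmt.Println(legacyOwner(fOwnerEx{Type: F_OWNER_PID, PID: 42}))  // 42
}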
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
index b9bd25464..bde17a767 100644
--- a/pkg/sentry/syscalls/linux/sys_futex.go
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -226,6 +226,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if mask == 0 {
return 0, nil, syserror.EINVAL
}
+ if val <= 0 {
+ // The Linux kernel wakes one waiter even if val is
+ // non-positive.
+ val = 1
+ }
n, err := t.Futex().Wake(t, addr, private, mask, val)
return uintptr(n), nil, err
@@ -242,6 +247,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.FUTEX_WAKE_OP:
op := uint32(val3)
+ if val <= 0 {
+ // The Linux kernel wakes one waiter even if val is
+ // non-positive.
+ val = 1
+ }
n, err := t.Futex().WakeOp(t, addr, naddr, private, val, nreq, op)
return uintptr(n), nil, err
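Both FUTEX_WAKE paths now normalize the requested wake count the same way. The rule in isolation, as a sketch (normalizeWakeCount is an illustrative helper, not the sentry's code):

// normalizeWakeCount mirrors the checks added above: Linux wakes a
// single waiter even when userspace passes a zero or negative count.
func normalizeWakeCount(val int) int {
	if val <= 0 {
		return 1
	}
	return val
}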
diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go
index 7a13beac2..2b2df989a 100644
--- a/pkg/sentry/syscalls/linux/sys_poll.go
+++ b/pkg/sentry/syscalls/linux/sys_poll.go
@@ -197,53 +197,51 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration)
return remainingTimeout, n, err
}
+// CopyInFDSet copies an fd set from select(2)/pselect(2).
+func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
+ set := make([]byte, nBytes)
+
+ if addr != 0 {
+ if _, err := t.CopyIn(addr, &set); err != nil {
+ return nil, err
+ }
+ // If we only use part of the last byte, mask out the extraneous bits.
+ //
+ // N.B. This only works on little-endian architectures.
+ if nBitsInLastPartialByte != 0 {
+ set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
+ }
+ }
+ return set, nil
+}
+
func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
if nfds < 0 || nfds > fileCap {
return 0, syserror.EINVAL
}
- // Capture all the provided input vectors.
- //
- // N.B. This only works on little-endian architectures.
- byteCount := (nfds + 7) / 8
-
- bitsInLastPartialByte := uint(nfds % 8)
- r := make([]byte, byteCount)
- w := make([]byte, byteCount)
- e := make([]byte, byteCount)
+ // Calculate the size of the fd sets (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := nfds % 8
- if readFDs != 0 {
- if _, err := t.CopyIn(readFDs, &r); err != nil {
- return 0, err
- }
- // Mask out bits above nfds.
- if bitsInLastPartialByte != 0 {
- r[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte
- }
+ // Capture all the provided input vectors.
+ r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
}
-
- if writeFDs != 0 {
- if _, err := t.CopyIn(writeFDs, &w); err != nil {
- return 0, err
- }
- if bitsInLastPartialByte != 0 {
- w[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte
- }
+ w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
}
-
- if exceptFDs != 0 {
- if _, err := t.CopyIn(exceptFDs, &e); err != nil {
- return 0, err
- }
- if bitsInLastPartialByte != 0 {
- e[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte
- }
+ e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
}
// Count how many FDs are actually being requested so that we can build
// a PollFD array.
fdCount := 0
- for i := 0; i < byteCount; i++ {
+ for i := 0; i < nBytes; i++ {
v := r[i] | w[i] | e[i]
for v != 0 {
v &= (v - 1)
@@ -254,7 +252,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add
// Build the PollFD array.
pfd := make([]linux.PollFD, 0, fdCount)
var fd int32
- for i := 0; i < byteCount; i++ {
+ for i := 0; i < nBytes; i++ {
rV, wV, eV := r[i], w[i], e[i]
v := rV | wV | eV
m := byte(1)
@@ -295,8 +293,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add
}
// Do the syscall, then count the number of bits set.
- _, _, err := pollBlock(t, pfd, timeout)
- if err != nil {
+ if _, _, err = pollBlock(t, pfd, timeout); err != nil {
return 0, syserror.ConvertIntr(err, syserror.EINTR)
}
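The masking done by CopyInFDSet matters because select(2) bitmaps are sized in bits, not bytes: any bits above nfds in the final byte must be ignored. A self-contained illustration of the last-byte mask for nfds = 10:

package main

import "fmt"

func main() {
	nfds := 10
	nBytes := (nfds + 7) / 8           // 2 bytes back the 10-bit set
	nBitsInLastPartialByte := nfds % 8 // only 2 bits of the last byte are used

	set := []byte{0xff, 0xff} // as copied in from userspace
	if nBitsInLastPartialByte != 0 {
		set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
	}
	fmt.Printf("% x\n", set) // prints "ff 03": bits 10..15 cleared
}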
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index b5a72ce63..4b5aafcc0 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -447,16 +447,13 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.ENOTSOCK
}
- // Read the length if present. Reject negative values.
+ // Read the length. Reject negative values.
optLen := int32(0)
- if optLenAddr != 0 {
- if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
- return 0, nil, err
- }
-
- if optLen < 0 {
- return 0, nil, syserror.EINVAL
- }
+ if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
+ return 0, nil, err
+ }
+ if optLen < 0 {
+ return 0, nil, syserror.EINVAL
}
// Call syscall implementation then copy both value and value len out.
@@ -465,11 +462,9 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, e.ToError()
}
- if optLenAddr != 0 {
- vLen := int32(binary.Size(v))
- if _, err := t.CopyOut(optLenAddr, vLen); err != nil {
- return 0, nil, err
- }
+ vLen := int32(binary.Size(v))
+ if _, err := t.CopyOut(optLenAddr, vLen); err != nil {
+ return 0, nil, err
}
if v != nil {
@@ -769,7 +764,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
}
if !cms.Unix.Empty() {
mflags |= linux.MSG_CTRUNC
- cms.Unix.Release()
+ cms.Release()
}
if int(msg.Flags) != mflags {
@@ -789,24 +784,16 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
if e != nil {
return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
}
- defer cms.Unix.Release()
+ defer cms.Release()
controlData := make([]byte, 0, msg.ControlLen)
+ controlData = control.PackControlMessages(t, cms, controlData)
if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
controlData, mflags = control.PackCredentials(t, creds, controlData, mflags)
}
- if cms.IP.HasTimestamp {
- controlData = control.PackTimestamp(t, cms.IP.Timestamp, controlData)
- }
-
- if cms.IP.HasInq {
- // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP.
- controlData = control.PackInq(t, cms.IP.Inq, controlData)
- }
-
if cms.Unix.Rights != nil {
controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags)
}
@@ -882,7 +869,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag
}
n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
- cm.Unix.Release()
+ cm.Release()
if e != nil {
return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
}
@@ -1065,7 +1052,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme
}
// Call the syscall implementation.
- n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: controlMessages})
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages)
err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file)
if err != nil {
controlMessages.Release()
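The repeated change from cms.Unix.Release() to cms.Release() implies that socket.ControlMessages gained a single Release method covering both the Unix and IP halves. A plausible shape for that helper (a sketch, not necessarily the exact implementation):

// Release releases both categories of control messages, so callers no
// longer need to know that only the Unix messages (rights, credentials)
// hold references.
func (c *ControlMessages) Release() {
	c.Unix.Release()
	// IP-level messages (timestamps, TCP inq) carry plain values and
	// need no cleanup, but routing all call sites through one entry
	// point keeps them uniform if that ever changes.
}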
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 8ab7ffa25..4115116ff 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -15,12 +15,15 @@
package linux
import (
+ "path"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -67,8 +70,22 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
argvAddr := args[1].Pointer()
envvAddr := args[2].Pointer()
- // Extract our arguments.
- filename, err := t.CopyInString(filenameAddr, linux.PATH_MAX)
+ return execveat(t, linux.AT_FDCWD, filenameAddr, argvAddr, envvAddr, 0)
+}
+
+// Execveat implements linux syscall execveat(2).
+func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := args[0].Int()
+ pathnameAddr := args[1].Pointer()
+ argvAddr := args[2].Pointer()
+ envvAddr := args[3].Pointer()
+ flags := args[4].Int()
+
+ return execveat(t, dirFD, pathnameAddr, argvAddr, envvAddr, flags)
+}
+
+func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX)
if err != nil {
return 0, nil, err
}
@@ -89,14 +106,66 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
}
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ atEmptyPath := flags&linux.AT_EMPTY_PATH != 0
+ if !atEmptyPath && len(pathname) == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ resolveFinal := flags&linux.AT_SYMLINK_NOFOLLOW == 0
+
root := t.FSContext().RootDirectory()
defer root.DecRef()
- wd := t.FSContext().WorkingDirectory()
- defer wd.DecRef()
+
+ var wd *fs.Dirent
+ var executable *fs.File
+ var closeOnExec bool
+ if dirFD == linux.AT_FDCWD || path.IsAbs(pathname) {
+ // Even if the pathname is absolute, we may still need the wd
+ // for interpreter scripts if the path of the interpreter is
+ // relative.
+ wd = t.FSContext().WorkingDirectory()
+ } else {
+ // Need to extract the given FD.
+ f, fdFlags := t.FDTable().Get(dirFD)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+ closeOnExec = fdFlags.CloseOnExec
+
+ if atEmptyPath && len(pathname) == 0 {
+ executable = f
+ } else {
+ wd = f.Dirent
+ wd.IncRef()
+ if !fs.IsDir(wd.Inode.StableAttr) {
+ return 0, nil, syserror.ENOTDIR
+ }
+ }
+ }
+ if wd != nil {
+ defer wd.DecRef()
+ }
// Load the new TaskContext.
- maxTraversals := uint(linux.MaxSymlinkTraversals)
- tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, filename, nil, argv, envv, t.Arch().FeatureSet())
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ loadArgs := loader.LoadArgs{
+ Mounts: t.MountNamespace(),
+ Root: root,
+ WorkingDirectory: wd,
+ RemainingTraversals: &remainingTraversals,
+ ResolveFinal: resolveFinal,
+ Filename: pathname,
+ File: executable,
+ CloseOnExec: closeOnExec,
+ Argv: argv,
+ Envv: envv,
+ Features: t.Arch().FeatureSet(),
+ }
+
+ tc, se := t.Kernel().LoadTaskImage(t, loadArgs)
if se != nil {
return 0, nil, se.ToError()
}
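The flag validation added for execveat(2) is self-contained enough to show standalone. A sketch using the raw Linux constant values (checkExecveatFlags is an illustrative helper, not sentry code):

package main

import "syscall"

const (
	AT_SYMLINK_NOFOLLOW = 0x100  // as in Linux's fcntl.h
	AT_EMPTY_PATH       = 0x1000
)

// checkExecveatFlags applies the same rejections as the new execveat
// path: unknown flags are EINVAL, and an empty pathname is only legal
// with AT_EMPTY_PATH.
func checkExecveatFlags(flags int32, pathname string) error {
	if flags&^(AT_EMPTY_PATH|AT_SYMLINK_NOFOLLOW) != 0 {
		return syscall.EINVAL
	}
	if flags&AT_EMPTY_PATH == 0 && len(pathname) == 0 {
		return syscall.ENOENT
	}
	return nil
}

func main() {}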
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index 27cd2c336..ad4b67806 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -191,7 +191,6 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
// Pwritev2 implements linux syscall pwritev2(2).
-// TODO(b/120161091): Implement O_SYNC and D_SYNC functionality.
func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
// While the syscall is
// pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
new file mode 100644
index 000000000..97d9a65ea
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -0,0 +1,169 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Getxattr implements linux syscall getxattr(2).
+func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ valueLen := 0
+ err = fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ value, err := getxattr(t, d, dirPath, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ valueLen = len(value)
+ if size == 0 {
+ return nil
+ }
+ if size > linux.XATTR_SIZE_MAX {
+ size = linux.XATTR_SIZE_MAX
+ }
+ if valueLen > int(size) {
+ return syserror.ERANGE
+ }
+
+ _, err = t.CopyOutBytes(valueAddr, []byte(value))
+ return err
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(valueLen), nil, nil
+}
+
+// getxattr implements getxattr from the given *fs.Dirent.
+func getxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr usermem.Addr) (string, error) {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return "", syserror.ENOTDIR
+ }
+
+ if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Read: true}); err != nil {
+ return "", err
+ }
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return "", err
+ }
+
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return "", syserror.EOPNOTSUPP
+ }
+
+ return d.Inode.Getxattr(name)
+}
+
+// Setxattr implements linux syscall setxattr(2).
+func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Uint()
+
+ path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ return setxattr(t, d, dirPath, nameAddr, valueAddr, size, flags)
+ })
+}
+
+// setxattr implements setxattr from the given *fs.Dirent.
+func setxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr, valueAddr usermem.Addr, size uint, flags uint32) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Write: true}); err != nil {
+ return err
+ }
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ if size > linux.XATTR_SIZE_MAX {
+ return syserror.E2BIG
+ }
+ buf := make([]byte, size)
+ if _, err = t.CopyInBytes(valueAddr, buf); err != nil {
+ return err
+ }
+ value := string(buf)
+
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+
+ return d.Inode.Setxattr(name, value)
+}
+
+func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
+ name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
+ if err != nil {
+ if err == syserror.ENAMETOOLONG {
+ return "", syserror.ERANGE
+ }
+ return "", err
+ }
+ if len(name) == 0 {
+ return "", syserror.ERANGE
+ }
+ return name, nil
+}
+
+func checkXattrPermissions(t *kernel.Task, i *fs.Inode, perms fs.PermMask) error {
+ // Restrict xattrs to regular files and directories.
+ //
+ // In Linux, this restriction technically only applies to xattrs in the
+ // "user.*" namespace, but we don't allow any other xattr prefixes anyway.
+ if !fs.IsRegular(i.StableAttr) && !fs.IsDir(i.StableAttr) {
+ if perms.Write {
+ return syserror.EPERM
+ }
+ return syserror.ENODATA
+ }
+
+ return i.CheckPermission(t, perms)
+}
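copyInXattrName and the user-prefix checks combine into a simple policy: names must be non-empty, at most XATTR_NAME_MAX bytes, and (for now) in the "user." namespace. Folded into one standalone validator (checkXattrName is illustrative, not sentry code; the real code derives ERANGE from CopyInString hitting its length cap):

package main

import (
	"strings"
	"syscall"
)

const (
	xattrNameMax    = 255     // linux.XATTR_NAME_MAX
	xattrUserPrefix = "user." // linux.XATTR_USER_PREFIX
)

// checkXattrName combines the ERANGE and EOPNOTSUPP checks made by the
// new getxattr/setxattr paths.
func checkXattrName(name string) error {
	if len(name) == 0 || len(name) > xattrNameMax {
		return syscall.ERANGE
	}
	if !strings.HasPrefix(name, xattrUserPrefix) {
		return syscall.EOPNOTSUPP
	}
	return nil
}

func main() {}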
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index d3a4cd943..18e212dff 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -9,7 +9,7 @@ go_template_instance(
out = "seqatomic_parameters_unsafe.go",
package = "time",
suffix = "Parameters",
- template = "//third_party/gvsync:generic_seqatomic",
+ template = "//pkg/syncutil:generic_seqatomic",
types = {
"Value": "Parameters",
},
@@ -36,8 +36,8 @@ go_library(
deps = [
"//pkg/log",
"//pkg/metric",
+ "//pkg/syncutil",
"//pkg/syserror",
- "//third_party/gvsync",
],
)
diff --git a/pkg/sentry/usermem/bytes_io.go b/pkg/sentry/usermem/bytes_io.go
index 8d88396ba..7898851b3 100644
--- a/pkg/sentry/usermem/bytes_io.go
+++ b/pkg/sentry/usermem/bytes_io.go
@@ -102,19 +102,34 @@ func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) {
}
func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, error) {
- blocks := make([]safemem.Block, 0, ars.NumRanges())
- for !ars.IsEmpty() {
- ar := ars.Head()
- n, err := b.rangeCheck(ar.Start, int(ar.Length()))
- if n != 0 {
- blocks = append(blocks, safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start):int(ar.Start)+n]))
+ switch ars.NumRanges() {
+ case 0:
+ return safemem.BlockSeq{}, nil
+ case 1:
+ block, err := b.blockFromAddrRange(ars.Head())
+ return safemem.BlockSeqOf(block), err
+ default:
+ blocks := make([]safemem.Block, 0, ars.NumRanges())
+ for !ars.IsEmpty() {
+ block, err := b.blockFromAddrRange(ars.Head())
+ if block.Len() != 0 {
+ blocks = append(blocks, block)
+ }
+ if err != nil {
+ return safemem.BlockSeqFromSlice(blocks), err
+ }
+ ars = ars.Tail()
}
- if err != nil {
- return safemem.BlockSeqFromSlice(blocks), err
- }
- ars = ars.Tail()
+ return safemem.BlockSeqFromSlice(blocks), nil
+ }
+}
+
+func (b *BytesIO) blockFromAddrRange(ar AddrRange) (safemem.Block, error) {
+ n, err := b.rangeCheck(ar.Start, int(ar.Length()))
+ if n == 0 {
+ return safemem.Block{}, err
}
- return safemem.BlockSeqFromSlice(blocks), nil
+ return safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start) : int(ar.Start)+n]), err
}
// BytesIOSequence returns an IOSequence representing the given byte slice.
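The refactoring above is as much an optimization as a cleanup: the single-range case, by far the most common, now goes through safemem.BlockSeqOf and never allocates a []safemem.Block. The pattern with simplified types (a sketch, assuming a sequence type that can hold one element inline, as safemem.BlockSeq does):

// seq holds a single block inline, so the common one-range case needs
// no heap allocation; only multi-range callers pay for a slice.
type block []byte

type seq struct {
	single block   // valid when n == 1
	many   []block // valid when n > 1
	n      int
}

func seqOf(b block) seq           { return seq{single: b, n: 1} }
func seqFromSlice(bs []block) seq { return seq{many: bs, n: len(bs)} }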
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index eff4b44f6..e3e554b88 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -12,13 +12,14 @@ go_library(
"file_description.go",
"file_description_impl_util.go",
"filesystem.go",
+ "filesystem_impl_util.go",
"filesystem_type.go",
"mount.go",
"mount_unsafe.go",
"options.go",
+ "pathname.go",
"permissions.go",
"resolving_path.go",
- "syscalls.go",
"testutil.go",
"vfs.go",
],
@@ -32,9 +33,9 @@ go_library(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
"//pkg/sentry/usermem",
+ "//pkg/syncutil",
"//pkg/syserror",
"//pkg/waiter",
- "//third_party/gvsync",
],
)
diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md
index 7847854bc..9aa133bcb 100644
--- a/pkg/sentry/vfs/README.md
+++ b/pkg/sentry/vfs/README.md
@@ -39,8 +39,8 @@ Mount references are held by:
- Mount: Each referenced Mount holds a reference on its parent, which is the
mount containing its mount point.
-- VirtualFilesystem: A reference is held on all Mounts that are attached
- (reachable by Mount traversal).
+- VirtualFilesystem: A reference is held on each Mount that has not been
+ umounted.
MountNamespace and FileDescription references are held by users of VFS. The
expectation is that each `kernel.Task` holds a reference on its corresponding
diff --git a/pkg/sentry/vfs/context.go b/pkg/sentry/vfs/context.go
index 32cf9151b..705194ebc 100644
--- a/pkg/sentry/vfs/context.go
+++ b/pkg/sentry/vfs/context.go
@@ -24,6 +24,9 @@ type contextID int
const (
// CtxMountNamespace is a Context.Value key for a MountNamespace.
CtxMountNamespace contextID = iota
+
+ // CtxRoot is a Context.Value key for a VFS root.
+ CtxRoot
)
// MountNamespaceFromContext returns the MountNamespace used by ctx. It does
@@ -35,3 +38,13 @@ func MountNamespaceFromContext(ctx context.Context) *MountNamespace {
}
return nil
}
+
+// RootFromContext returns the VFS root used by ctx. It takes a reference on
+// the returned VirtualDentry. If ctx does not have a specific VFS root,
+// RootFromContext returns a zero-value VirtualDentry.
+func RootFromContext(ctx context.Context) VirtualDentry {
+ if v := ctx.Value(CtxRoot); v != nil {
+ return v.(VirtualDentry)
+ }
+ return VirtualDentry{}
+}
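RootFromContext implies a matching producer: something in the call chain must answer Value(CtxRoot) with a referenced VirtualDentry. One way that plumbing could look inside package vfs (rootContext is a hypothetical wrapper, not sentry code):

// rootContext decorates a Context with a specific VFS root.
type rootContext struct {
	context.Context
	root VirtualDentry
}

// Value takes the reference that RootFromContext promises to hand to
// its caller; all other keys fall through to the wrapped Context.
func (c rootContext) Value(key interface{}) interface{} {
	if key == CtxRoot {
		c.root.IncRef()
		return c.root
	}
	return c.Context.Value(key)
}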
diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go
index 45912fc58..6209eb053 100644
--- a/pkg/sentry/vfs/dentry.go
+++ b/pkg/sentry/vfs/dentry.go
@@ -16,6 +16,7 @@ package vfs
import (
"fmt"
+ "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/syserror"
@@ -50,7 +51,7 @@ import (
// and not inodes. Furthermore, when parties outside the scope of VFS can
// rename inodes on such filesystems, VFS generally cannot "follow" the rename,
// both due to synchronization issues and because it may not even be able to
-// name the destination path; this implies that it would in fact be *incorrect*
+// name the destination path; this implies that it would in fact be incorrect
// for Dentries to be associated with inodes on such filesystems. Consequently,
// operations that are inode operations in Linux are FilesystemImpl methods
// and/or FileDescriptionImpl methods in gVisor's VFS. Filesystems that do
@@ -87,6 +88,9 @@ type Dentry struct {
// children are child Dentries.
children map[string]*Dentry
+ // mu synchronizes disowning and mounting over this Dentry.
+ mu sync.Mutex
+
// impl is the DentryImpl associated with this Dentry. impl is immutable.
// This should be the last field in Dentry.
impl DentryImpl
@@ -114,7 +118,7 @@ func (d *Dentry) Impl() DentryImpl {
type DentryImpl interface {
// IncRef increments the Dentry's reference count. A Dentry with a non-zero
// reference count must remain coherent with the state of the filesystem.
- IncRef(fs *Filesystem)
+ IncRef()
// TryIncRef increments the Dentry's reference count and returns true. If
// the Dentry's reference count is zero, TryIncRef may do nothing and
@@ -122,10 +126,10 @@ type DentryImpl interface {
// guarantee that the Dentry is coherent with the state of the filesystem.)
//
// TryIncRef does not require that a reference is held on the Dentry.
- TryIncRef(fs *Filesystem) bool
+ TryIncRef() bool
// DecRef decrements the Dentry's reference count.
- DecRef(fs *Filesystem)
+ DecRef()
}
// IsDisowned returns true if d is disowned.
@@ -142,16 +146,20 @@ func (d *Dentry) isMounted() bool {
return atomic.LoadUint32(&d.mounts) != 0
}
-func (d *Dentry) incRef(fs *Filesystem) {
- d.impl.IncRef(fs)
+// IncRef increments d's reference count.
+func (d *Dentry) IncRef() {
+ d.impl.IncRef()
}
-func (d *Dentry) tryIncRef(fs *Filesystem) bool {
- return d.impl.TryIncRef(fs)
+// TryIncRef increments d's reference count and returns true. If d's reference
+// count is zero, TryIncRef may instead do nothing and return false.
+func (d *Dentry) TryIncRef() bool {
+ return d.impl.TryIncRef()
}
-func (d *Dentry) decRef(fs *Filesystem) {
- d.impl.DecRef(fs)
+// DecRef decrements d's reference count.
+func (d *Dentry) DecRef() {
+ d.impl.DecRef()
}
// These functions are exported so that filesystem implementations can use
@@ -191,6 +199,18 @@ func (d *Dentry) HasChildren() bool {
return len(d.children) != 0
}
+// Children returns a map containing all of d's children.
+func (d *Dentry) Children() map[string]*Dentry {
+ if !d.HasChildren() {
+ return nil
+ }
+ m := make(map[string]*Dentry)
+ for name, child := range d.children {
+ m[name] = child
+ }
+ return m
+}
+
// InsertChild makes child a child of d with the given name.
//
// InsertChild is a mutator of d and child.
@@ -228,36 +248,48 @@ func (vfs *VirtualFilesystem) PrepareDeleteDentry(mntns *MountNamespace, d *Dent
panic("d is already disowned")
}
}
- vfs.mountMu.RLock()
- if _, ok := mntns.mountpoints[d]; ok {
- vfs.mountMu.RUnlock()
+ vfs.mountMu.Lock()
+ if mntns.mountpoints[d] != 0 {
+ vfs.mountMu.Unlock()
return syserror.EBUSY
}
- // Return with vfs.mountMu locked, which will be unlocked by
- // AbortDeleteDentry or CommitDeleteDentry.
+ d.mu.Lock()
+ vfs.mountMu.Unlock()
+ // Return with d.mu locked to block attempts to mount over it; it will be
+ // unlocked by AbortDeleteDentry or CommitDeleteDentry.
return nil
}
// AbortDeleteDentry must be called after PrepareDeleteDentry if the deletion
// fails.
-func (vfs *VirtualFilesystem) AbortDeleteDentry() {
- vfs.mountMu.RUnlock()
+func (vfs *VirtualFilesystem) AbortDeleteDentry(d *Dentry) {
+ d.mu.Unlock()
}
// CommitDeleteDentry must be called after the file represented by d is
// deleted, and causes d to become disowned.
//
+// CommitDeleteDentry is a mutator of d and d.Parent().
+//
// Preconditions: PrepareDeleteDentry was previously called on d.
func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) {
- delete(d.parent.children, d.name)
+ if d.parent != nil {
+ delete(d.parent.children, d.name)
+ }
d.setDisowned()
- // TODO: lazily unmount mounts at d
- vfs.mountMu.RUnlock()
+ d.mu.Unlock()
+ if d.isMounted() {
+ vfs.forgetDisownedMountpoint(d)
+ }
}
// DeleteDentry combines PrepareDeleteDentry and CommitDeleteDentry, as
// appropriate for in-memory filesystems that don't need to ensure that some
// external state change succeeds before committing the deletion.
+//
+// DeleteDentry is a mutator of d and d.Parent().
+//
+// Preconditions: d is a child Dentry.
func (vfs *VirtualFilesystem) DeleteDentry(mntns *MountNamespace, d *Dentry) error {
if err := vfs.PrepareDeleteDentry(mntns, d); err != nil {
return err
@@ -266,6 +298,27 @@ func (vfs *VirtualFilesystem) DeleteDentry(mntns *MountNamespace, d *Dentry) err
return nil
}
+// ForceDeleteDentry causes d to become disowned. It should only be used in
+// cases where VFS has no ability to stop the deletion (e.g. d represents the
+// local state of a file on a remote filesystem on which the file has already
+// been deleted).
+//
+// ForceDeleteDentry is a mutator of d and d.Parent().
+//
+// Preconditions: d is a child Dentry.
+func (vfs *VirtualFilesystem) ForceDeleteDentry(d *Dentry) {
+ if checkInvariants {
+ if d.parent == nil {
+ panic("d is independent")
+ }
+ if d.IsDisowned() {
+ panic("d is already disowned")
+ }
+ }
+ d.mu.Lock()
+ vfs.CommitDeleteDentry(d)
+}
+
// PrepareRenameDentry must be called before attempting to rename the file
// represented by from. If to is not nil, it represents the file that will be
// replaced or exchanged by the rename. If PrepareRenameDentry succeeds, the
@@ -291,18 +344,21 @@ func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, t
}
}
}
- vfs.mountMu.RLock()
- if _, ok := mntns.mountpoints[from]; ok {
- vfs.mountMu.RUnlock()
+ vfs.mountMu.Lock()
+ if mntns.mountpoints[from] != 0 {
+ vfs.mountMu.Unlock()
return syserror.EBUSY
}
if to != nil {
- if _, ok := mntns.mountpoints[to]; ok {
- vfs.mountMu.RUnlock()
+ if mntns.mountpoints[to] != 0 {
+ vfs.mountMu.Unlock()
return syserror.EBUSY
}
+ to.mu.Lock()
}
- // Return with vfs.mountMu locked, which will be unlocked by
+ from.mu.Lock()
+ vfs.mountMu.Unlock()
+ // Return with from.mu and to.mu locked, which will be unlocked by
// AbortRenameDentry, CommitRenameReplaceDentry, or
// CommitRenameExchangeDentry.
return nil
@@ -310,38 +366,76 @@ func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, t
// AbortRenameDentry must be called after PrepareRenameDentry if the rename
// fails.
-func (vfs *VirtualFilesystem) AbortRenameDentry() {
- vfs.mountMu.RUnlock()
+func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) {
+ from.mu.Unlock()
+ if to != nil {
+ to.mu.Unlock()
+ }
}
// CommitRenameReplaceDentry must be called after the file represented by from
// is renamed without RENAME_EXCHANGE. If to is not nil, it represents the file
// that was replaced by from.
//
+// CommitRenameReplaceDentry is a mutator of from, to, from.Parent(), and
+// to.Parent().
+//
// Preconditions: PrepareRenameDentry was previously called on from and to.
// newParent.Child(newName) == to.
func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(from, newParent *Dentry, newName string, to *Dentry) {
- if to != nil {
- to.setDisowned()
- // TODO: lazily unmount mounts at d
- }
if newParent.children == nil {
newParent.children = make(map[string]*Dentry)
}
newParent.children[newName] = from
from.parent = newParent
from.name = newName
- vfs.mountMu.RUnlock()
+ from.mu.Unlock()
+ if to != nil {
+ to.setDisowned()
+ to.mu.Unlock()
+ if to.isMounted() {
+ vfs.forgetDisownedMountpoint(to)
+ }
+ }
}
// CommitRenameExchangeDentry must be called after the files represented by
// from and to are exchanged by rename(RENAME_EXCHANGE).
//
+// CommitRenameExchangeDentry is a mutator of from, to, from.Parent(), and
+// to.Parent().
+//
// Preconditions: PrepareRenameDentry was previously called on from and to.
func (vfs *VirtualFilesystem) CommitRenameExchangeDentry(from, to *Dentry) {
from.parent, to.parent = to.parent, from.parent
from.name, to.name = to.name, from.name
from.parent.children[from.name] = from
to.parent.children[to.name] = to
- vfs.mountMu.RUnlock()
+ from.mu.Unlock()
+ to.mu.Unlock()
+}
+
+// forgetDisownedMountpoint is called when a mount point is deleted to umount
+// all mounts using it in all other mount namespaces.
+//
+// forgetDisownedMountpoint is analogous to Linux's
+// fs/namespace.c:__detach_mounts().
+func (vfs *VirtualFilesystem) forgetDisownedMountpoint(d *Dentry) {
+ var (
+ vdsToDecRef []VirtualDentry
+ mountsToDecRef []*Mount
+ )
+ vfs.mountMu.Lock()
+ vfs.mounts.seq.BeginWrite()
+ for mnt := range vfs.mountpoints[d] {
+ vdsToDecRef, mountsToDecRef = vfs.umountRecursiveLocked(mnt, &umountRecursiveOptions{}, vdsToDecRef, mountsToDecRef)
+ }
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ for _, vd := range vdsToDecRef {
+ vd.DecRef()
+ }
+ for _, mnt := range mountsToDecRef {
+ mnt.DecRef()
+ }
}
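Taken together, PrepareDeleteDentry, AbortDeleteDentry, and CommitDeleteDentry form a two-phase protocol against the new per-Dentry lock. A typical caller, as a filesystem's unlink might use it (removeFromStore is a hypothetical backing-store operation):

func unlink(vfs *VirtualFilesystem, mntns *MountNamespace, d *Dentry) error {
	// Phase 1: fail with EBUSY if d is a mount point; on success, d.mu
	// is held to block concurrent mounts over d.
	if err := vfs.PrepareDeleteDentry(mntns, d); err != nil {
		return err
	}
	// External state change that may fail.
	if err := removeFromStore(d); err != nil {
		vfs.AbortDeleteDentry(d) // releases d.mu; nothing disowned
		return err
	}
	// Phase 2: disown d and lazily unmount anything mounted at it.
	vfs.CommitDeleteDentry(d)
	return nil
}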
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 3a9665800..df03886c3 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -20,8 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -38,43 +40,64 @@ type FileDescription struct {
// operations.
refs int64
+ // statusFlags contains status flags, "initialized by open(2) and possibly
+ // modified by fcntl()" - fcntl(2). statusFlags is accessed using atomic
+ // memory operations.
+ statusFlags uint32
+
// vd is the filesystem location at which this FileDescription was opened.
// A reference is held on vd. vd is immutable.
vd VirtualDentry
+ opts FileDescriptionOptions
+
// impl is the FileDescriptionImpl associated with this Filesystem. impl is
// immutable. This should be the last field in FileDescription.
impl FileDescriptionImpl
}
-// Init must be called before first use of fd. It takes references on mnt and
-// d.
-func (fd *FileDescription) Init(impl FileDescriptionImpl, mnt *Mount, d *Dentry) {
+// FileDescriptionOptions contains options to FileDescription.Init().
+type FileDescriptionOptions struct {
+ // If AllowDirectIO is true, allow O_DIRECT to be set on the file. This is
+ // usually only the case if O_DIRECT would actually have an effect.
+ AllowDirectIO bool
+}
+
+// Init must be called before first use of fd. It takes ownership of references
+// on mnt and d held by the caller. statusFlags is the initial file description
+// status flags, which is usually the full set of flags passed to open(2).
+func (fd *FileDescription) Init(impl FileDescriptionImpl, statusFlags uint32, mnt *Mount, d *Dentry, opts *FileDescriptionOptions) {
fd.refs = 1
+ fd.statusFlags = statusFlags | linux.O_LARGEFILE
fd.vd = VirtualDentry{
mount: mnt,
dentry: d,
}
- fd.vd.IncRef()
+ fd.opts = *opts
fd.impl = impl
}
-// Impl returns the FileDescriptionImpl associated with fd.
-func (fd *FileDescription) Impl() FileDescriptionImpl {
- return fd.impl
-}
-
-// VirtualDentry returns the location at which fd was opened. It does not take
-// a reference on the returned VirtualDentry.
-func (fd *FileDescription) VirtualDentry() VirtualDentry {
- return fd.vd
-}
-
// IncRef increments fd's reference count.
func (fd *FileDescription) IncRef() {
atomic.AddInt64(&fd.refs, 1)
}
+// TryIncRef increments fd's reference count and returns true. If fd's
+// reference count is already zero, TryIncRef does nothing and returns false.
+//
+// TryIncRef does not require that a reference is held on fd.
+func (fd *FileDescription) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&fd.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&fd.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
// DecRef decrements fd's reference count.
func (fd *FileDescription) DecRef() {
if refs := atomic.AddInt64(&fd.refs, -1); refs == 0 {
@@ -85,6 +108,82 @@ func (fd *FileDescription) DecRef() {
}
}
+// Mount returns the mount on which fd was opened. It does not take a reference
+// on the returned Mount.
+func (fd *FileDescription) Mount() *Mount {
+ return fd.vd.mount
+}
+
+// Dentry returns the dentry at which fd was opened. It does not take a
+// reference on the returned Dentry.
+func (fd *FileDescription) Dentry() *Dentry {
+ return fd.vd.dentry
+}
+
+// VirtualDentry returns the location at which fd was opened. It does not take
+// a reference on the returned VirtualDentry.
+func (fd *FileDescription) VirtualDentry() VirtualDentry {
+ return fd.vd
+}
+
+// StatusFlags returns file description status flags, as for fcntl(F_GETFL).
+func (fd *FileDescription) StatusFlags() uint32 {
+ return atomic.LoadUint32(&fd.statusFlags)
+}
+
+// SetStatusFlags sets file description status flags, as for fcntl(F_SETFL).
+func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Credentials, flags uint32) error {
+ // Compare Linux's fs/fcntl.c:setfl().
+ oldFlags := fd.StatusFlags()
+ // Linux documents this check as "O_APPEND cannot be cleared if the file is
+ // marked as append-only and the file is open for write", which would make
+ // sense. However, the check as actually implemented seems to be "O_APPEND
+ // cannot be changed if the file is marked as append-only".
+ if (flags^oldFlags)&linux.O_APPEND != 0 {
+ stat, err := fd.impl.Stat(ctx, StatOptions{
+ // There is no mask bit for stx_attributes.
+ Mask: 0,
+ // Linux just reads inode::i_flags directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil {
+ return err
+ }
+ if (stat.AttributesMask&linux.STATX_ATTR_APPEND != 0) && (stat.Attributes&linux.STATX_ATTR_APPEND != 0) {
+ return syserror.EPERM
+ }
+ }
+ if (flags&linux.O_NOATIME != 0) && (oldFlags&linux.O_NOATIME == 0) {
+ stat, err := fd.impl.Stat(ctx, StatOptions{
+ Mask: linux.STATX_UID,
+ // Linux's inode_owner_or_capable() just reads inode::i_uid
+ // directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil {
+ return err
+ }
+ if stat.Mask&linux.STATX_UID == 0 {
+ return syserror.EPERM
+ }
+ if !CanActAsOwner(creds, auth.KUID(stat.UID)) {
+ return syserror.EPERM
+ }
+ }
+ if flags&linux.O_DIRECT != 0 && !fd.opts.AllowDirectIO {
+ return syserror.EINVAL
+ }
+ // TODO(jamieliu): FileDescriptionImpl.SetOAsync()?
+ const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK
+ atomic.StoreUint32(&fd.statusFlags, (oldFlags&^settableFlags)|(flags&settableFlags))
+ return nil
+}
+
+// Impl returns the FileDescriptionImpl associated with fd.
+func (fd *FileDescription) Impl() FileDescriptionImpl {
+ return fd.impl
+}
+
// FileDescriptionImpl contains implementation details for an FileDescription.
// Implementations of FileDescriptionImpl should contain their associated
// FileDescription by value as their first field.
@@ -104,14 +203,6 @@ type FileDescriptionImpl interface {
// prevent the file descriptor from being closed.
OnClose(ctx context.Context) error
- // StatusFlags returns file description status flags, as for
- // fcntl(F_GETFL).
- StatusFlags(ctx context.Context) (uint32, error)
-
- // SetStatusFlags sets file description status flags, as for
- // fcntl(F_SETFL).
- SetStatusFlags(ctx context.Context, flags uint32) error
-
// Stat returns metadata for the file represented by the FileDescription.
Stat(ctx context.Context, opts StatOptions) (linux.Statx, error)
@@ -185,7 +276,21 @@ type FileDescriptionImpl interface {
// Ioctl implements the ioctl(2) syscall.
Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error)
- // TODO: extended attributes; file locking
+ // Listxattr returns all extended attribute names for the file.
+ Listxattr(ctx context.Context) ([]string, error)
+
+ // Getxattr returns the value associated with the given extended attribute
+ // for the file.
+ Getxattr(ctx context.Context, name string) (string, error)
+
+ // Setxattr changes the value associated with the given extended attribute
+ // for the file.
+ Setxattr(ctx context.Context, opts SetxattrOptions) error
+
+ // Removexattr removes the given extended attribute from the file.
+ Removexattr(ctx context.Context, name string) error
+
+ // TODO: file locking
}
// Dirent holds the information contained in struct linux_dirent64.
@@ -214,3 +319,160 @@ type IterDirentsCallback interface {
// called.
Handle(dirent Dirent) bool
}
+
+// OnClose is called when a file descriptor representing the FileDescription is
+// closed. Returning a non-nil error should not prevent the file descriptor
+// from being closed.
+func (fd *FileDescription) OnClose(ctx context.Context) error {
+ return fd.impl.OnClose(ctx)
+}
+
+// Stat returns metadata for the file represented by fd.
+func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+ return fd.impl.Stat(ctx, opts)
+}
+
+// SetStat updates metadata for the file represented by fd.
+func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) error {
+ return fd.impl.SetStat(ctx, opts)
+}
+
+// StatFS returns metadata for the filesystem containing the file represented
+// by fd.
+func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
+ return fd.impl.StatFS(ctx)
+}
+
+// PRead reads from the file represented by fd into dst, starting at the given
+// offset, and returns the number of bytes read. PRead is permitted to return
+// partial reads with a nil error.
+func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ return fd.impl.PRead(ctx, dst, offset, opts)
+}
+
+// Read is similar to PRead, but does not specify an offset.
+func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ return fd.impl.Read(ctx, dst, opts)
+}
+
+// PWrite writes src to the file represented by fd, starting at the given
+// offset, and returns the number of bytes written. PWrite is permitted to
+// return partial writes with a nil error.
+func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ return fd.impl.PWrite(ctx, src, offset, opts)
+}
+
+// Write is similar to PWrite, but does not specify an offset.
+func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ return fd.impl.Write(ctx, src, opts)
+}
+
+// IterDirents invokes cb on each entry in the directory represented by fd. If
+// IterDirents has been called since the last call to Seek, it continues
+// iteration from the end of the last call.
+func (fd *FileDescription) IterDirents(ctx context.Context, cb IterDirentsCallback) error {
+ return fd.impl.IterDirents(ctx, cb)
+}
+
+// Seek changes fd's offset (assuming one exists) and returns its new value.
+func (fd *FileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return fd.impl.Seek(ctx, offset, whence)
+}
+
+// Sync has the semantics of fsync(2).
+func (fd *FileDescription) Sync(ctx context.Context) error {
+ return fd.impl.Sync(ctx)
+}
+
+// ConfigureMMap mutates opts to implement mmap(2) for the file represented by
+// fd.
+func (fd *FileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return fd.impl.ConfigureMMap(ctx, opts)
+}
+
+// Ioctl implements the ioctl(2) syscall.
+func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return fd.impl.Ioctl(ctx, uio, args)
+}
+
+// Listxattr returns all extended attribute names for the file represented by
+// fd.
+func (fd *FileDescription) Listxattr(ctx context.Context) ([]string, error) {
+ names, err := fd.impl.Listxattr(ctx)
+ if err == syserror.ENOTSUP {
+ // Linux doesn't actually return ENOTSUP in this case; instead,
+ // fs/xattr.c:vfs_listxattr() falls back to allowing the security
+ // subsystem to return security extended attributes, which by default
+ // don't exist.
+ return nil, nil
+ }
+ return names, err
+}
+
+// Getxattr returns the value associated with the given extended attribute for
+// the file represented by fd.
+func (fd *FileDescription) Getxattr(ctx context.Context, name string) (string, error) {
+ return fd.impl.Getxattr(ctx, name)
+}
+
+// Setxattr changes the value associated with the given extended attribute for
+// the file represented by fd.
+func (fd *FileDescription) Setxattr(ctx context.Context, opts SetxattrOptions) error {
+ return fd.impl.Setxattr(ctx, opts)
+}
+
+// Removexattr removes the given extended attribute from the file represented
+// by fd.
+func (fd *FileDescription) Removexattr(ctx context.Context, name string) error {
+ return fd.impl.Removexattr(ctx, name)
+}
+
+// SyncFS instructs the filesystem containing fd to execute the semantics of
+// syncfs(2).
+func (fd *FileDescription) SyncFS(ctx context.Context) error {
+ return fd.vd.mount.fs.impl.Sync(ctx)
+}
+
+// MappedName implements memmap.MappingIdentity.MappedName.
+func (fd *FileDescription) MappedName(ctx context.Context) string {
+ vfsroot := RootFromContext(ctx)
+ s, _ := fd.vd.mount.vfs.PathnameWithDeleted(ctx, vfsroot, fd.vd)
+ if vfsroot.Ok() {
+ vfsroot.DecRef()
+ }
+ return s
+}
+
+// DeviceID implements memmap.MappingIdentity.DeviceID.
+func (fd *FileDescription) DeviceID() uint64 {
+ stat, err := fd.impl.Stat(context.Background(), StatOptions{
+ // There is no STATX_DEV; we assume that Stat will return it if it's
+ // available regardless of mask.
+ Mask: 0,
+ // fs/proc/task_mmu.c:show_map_vma() just reads inode::i_sb->s_dev
+ // directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil {
+ return 0
+ }
+ return uint64(linux.MakeDeviceID(uint16(stat.DevMajor), stat.DevMinor))
+}
+
+// InodeID implements memmap.MappingIdentity.InodeID.
+func (fd *FileDescription) InodeID() uint64 {
+ stat, err := fd.impl.Stat(context.Background(), StatOptions{
+ Mask: linux.STATX_INO,
+ // fs/proc/task_mmu.c:show_map_vma() just reads inode::i_ino directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil || stat.Mask&linux.STATX_INO == 0 {
+ return 0
+ }
+ return stat.Ino
+}
+
+// Msync implements memmap.MappingIdentity.Msync.
+func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) error {
+ return fd.impl.Sync(ctx)
+}
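SetStatusFlags only ever rewrites the five settable bits; the access mode and creation-time flags are untouched. The masking rule in isolation, with the Linux/amd64 flag values spelled out (a sketch using local constants, not the sentry's):

// Flag values as on Linux/amd64.
const (
	oAppend   = 0x400
	oNonblock = 0x800
	oAsync    = 0x2000
	oDirect   = 0x4000
	oNoatime  = 0x40000

	settable = oAppend | oAsync | oDirect | oNoatime | oNonblock
)

// applySetFl computes the new status flags for fcntl(F_SETFL): bits
// outside the settable set (O_RDWR, O_LARGEFILE, ...) always survive.
func applySetFl(old, requested uint32) uint32 {
	return (old &^ settable) | (requested & settable)
}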
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 4fbad7840..3df49991c 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -127,6 +127,31 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg
return 0, syserror.ENOTTY
}
+// Listxattr implements FileDescriptionImpl.Listxattr analogously to
+// inode_operations::listxattr == NULL in Linux.
+func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context) ([]string, error) {
+ // This isn't exactly accurate; see FileDescription.Listxattr.
+ return nil, syserror.ENOTSUP
+}
+
+// Getxattr implements FileDescriptionImpl.Getxattr analogously to
+// inode::i_opflags & IOP_XATTR == 0 in Linux.
+func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, name string) (string, error) {
+ return "", syserror.ENOTSUP
+}
+
+// Setxattr implements FileDescriptionImpl.Setxattr analogously to
+// inode::i_opflags & IOP_XATTR == 0 in Linux.
+func (FileDescriptionDefaultImpl) Setxattr(ctx context.Context, opts SetxattrOptions) error {
+ return syserror.ENOTSUP
+}
+
+// Removexattr implements FileDescriptionImpl.Removexattr analogously to
+// inode::i_opflags & IOP_XATTR == 0 in Linux.
+func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) error {
+ return syserror.ENOTSUP
+}
+
// DirectoryFileDescriptionDefaultImpl may be embedded by implementations of
// FileDescriptionImpl that always represent directories to obtain
// implementations of non-directory I/O methods that return EISDIR.
@@ -252,3 +277,12 @@ func (fd *DynamicBytesFileDescriptionImpl) Seek(ctx context.Context, offset int6
fd.off = offset
return offset, nil
}
+
+// GenericConfigureMMap may be used by most implementations of
+// FileDescriptionImpl.ConfigureMMap.
+func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.MMapOpts) error {
+ opts.Mappable = m
+ opts.MappingIdentity = fd
+ fd.IncRef()
+ return nil
+}
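GenericConfigureMMap exists so that mappable file implementations can forward ConfigureMMap in one line. A sketch of the expected call pattern (regularFD and its data field are hypothetical):

// ConfigureMMap implements FileDescriptionImpl.ConfigureMMap by handing
// the FD's backing memmap.Mappable to the helper, which also takes the
// MappingIdentity reference on the FD itself.
func (fd *regularFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
	return vfs.GenericConfigureMMap(&fd.vfsfd, fd.data, opts)
}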
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
index 511b829fc..678be07fe 100644
--- a/pkg/sentry/vfs/file_description_impl_util_test.go
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -48,7 +48,7 @@ type genCountFD struct {
func newGenCountFD(mnt *Mount, vfsd *Dentry) *FileDescription {
var fd genCountFD
- fd.vfsfd.Init(&fd, mnt, vfsd)
+ fd.vfsfd.Init(&fd, 0 /* statusFlags */, mnt, vfsd, &FileDescriptionOptions{})
fd.DynamicBytesFileDescriptionImpl.SetDataSource(&fd)
return &fd.vfsfd
}
@@ -90,7 +90,7 @@ func TestGenCountFD(t *testing.T) {
vfsObj := New() // vfs.New()
vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &NewFilesystemOptions{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &GetFilesystemOptions{})
if err != nil {
t.Fatalf("failed to create testfs root mount: %v", err)
}
@@ -103,7 +103,7 @@ func TestGenCountFD(t *testing.T) {
// The first read causes Generate to be called to fill the FD's buffer.
buf := make([]byte, 2)
ioseq := usermem.BytesIOSequence(buf)
- n, err := fd.Impl().Read(ctx, ioseq, ReadOptions{})
+ n, err := fd.Read(ctx, ioseq, ReadOptions{})
if n != 1 || (err != nil && err != io.EOF) {
t.Fatalf("first Read: got (%d, %v), wanted (1, nil or EOF)", n, err)
}
@@ -112,17 +112,17 @@ func TestGenCountFD(t *testing.T) {
}
// A second read without seeking is still at EOF.
- n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{})
+ n, err = fd.Read(ctx, ioseq, ReadOptions{})
if n != 0 || err != io.EOF {
t.Fatalf("second Read: got (%d, %v), wanted (0, EOF)", n, err)
}
// Seeking to the beginning of the file causes it to be regenerated.
- n, err = fd.Impl().Seek(ctx, 0, linux.SEEK_SET)
+ n, err = fd.Seek(ctx, 0, linux.SEEK_SET)
if n != 0 || err != nil {
t.Fatalf("Seek: got (%d, %v), wanted (0, nil)", n, err)
}
- n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{})
+ n, err = fd.Read(ctx, ioseq, ReadOptions{})
if n != 1 || (err != nil && err != io.EOF) {
t.Fatalf("Read after Seek: got (%d, %v), wanted (1, nil or EOF)", n, err)
}
@@ -131,7 +131,7 @@ func TestGenCountFD(t *testing.T) {
}
// PRead at the beginning of the file also causes it to be regenerated.
- n, err = fd.Impl().PRead(ctx, ioseq, 0, ReadOptions{})
+ n, err = fd.PRead(ctx, ioseq, 0, ReadOptions{})
if n != 1 || (err != nil && err != io.EOF) {
t.Fatalf("PRead: got (%d, %v), wanted (1, nil or EOF)", n, err)
}
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 7a074b718..b766614e7 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -18,6 +18,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
)
@@ -33,15 +34,28 @@ type Filesystem struct {
// operations.
refs int64
+ // vfs is the VirtualFilesystem that uses this Filesystem. vfs is
+ // immutable.
+ vfs *VirtualFilesystem
+
// impl is the FilesystemImpl associated with this Filesystem. impl is
// immutable. This should be the last field in Dentry.
impl FilesystemImpl
}
// Init must be called before first use of fs.
-func (fs *Filesystem) Init(impl FilesystemImpl) {
+func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, impl FilesystemImpl) {
fs.refs = 1
+ fs.vfs = vfsObj
fs.impl = impl
+ vfsObj.filesystemsMu.Lock()
+ vfsObj.filesystems[fs] = struct{}{}
+ vfsObj.filesystemsMu.Unlock()
+}
+
+// VirtualFilesystem returns the containing VirtualFilesystem.
+func (fs *Filesystem) VirtualFilesystem() *VirtualFilesystem {
+ return fs.vfs
}
// Impl returns the FilesystemImpl associated with fs.
@@ -49,14 +63,35 @@ func (fs *Filesystem) Impl() FilesystemImpl {
return fs.impl
}
-func (fs *Filesystem) incRef() {
+// IncRef increments fs' reference count.
+func (fs *Filesystem) IncRef() {
if atomic.AddInt64(&fs.refs, 1) <= 1 {
- panic("Filesystem.incRef() called without holding a reference")
+ panic("Filesystem.IncRef() called without holding a reference")
+ }
+}
+
+// TryIncRef increments fs' reference count and returns true. If fs' reference
+// count is zero, TryIncRef does nothing and returns false.
+//
+// TryIncRef does not require that a reference is held on fs.
+func (fs *Filesystem) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&fs.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&fs.refs, refs, refs+1) {
+ return true
+ }
}
}
-func (fs *Filesystem) decRef() {
+// DecRef decrements fs' reference count.
+func (fs *Filesystem) DecRef() {
if refs := atomic.AddInt64(&fs.refs, -1); refs == 0 {
+ fs.vfs.filesystemsMu.Lock()
+ delete(fs.vfs.filesystems, fs)
+ fs.vfs.filesystemsMu.Unlock()
fs.impl.Release()
} else if refs < 0 {
panic("Filesystem.decRef() called without holding a reference")
@@ -151,5 +186,70 @@ type FilesystemImpl interface {
// UnlinkAt removes the non-directory file at rp.
UnlinkAt(ctx context.Context, rp *ResolvingPath) error
- // TODO: d_path(); extended attributes; inotify_add_watch(); bind()
+ // ListxattrAt returns all extended attribute names for the file at rp.
+ ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error)
+
+ // GetxattrAt returns the value associated with the given extended
+ // attribute for the file at rp.
+ GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error)
+
+ // SetxattrAt changes the value associated with the given extended
+ // attribute for the file at rp.
+ SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error
+
+ // RemovexattrAt removes the given extended attribute from the file at rp.
+ RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error
+
+ // PrependPath prepends a path from vd to vd.Mount().Root() to b.
+ //
+ // If vfsroot.Ok(), it is the contextual VFS root; if it is encountered
+ // before vd.Mount().Root(), PrependPath should stop prepending path
+ // components and return a PrependPathAtVFSRootError.
+ //
+ // If traversal of vd.Dentry()'s ancestors encounters an independent
+ // ("root") Dentry that is not vd.Mount().Root() (i.e. vd.Dentry() is not a
+ // descendant of vd.Mount().Root()), PrependPath should stop prepending
+ // path components and return a PrependPathAtNonMountRootError.
+ //
+ // Filesystems for which Dentries do not have meaningful paths may prepend
+ // an arbitrary descriptive string to b and then return a
+ // PrependPathSyntheticError.
+ //
+ // Most implementations can acquire the appropriate locks to ensure that
+ // Dentry.Name() and Dentry.Parent() are fixed for vd.Dentry() and all of
+ // its ancestors, then call GenericPrependPath.
+ //
+ // Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl.
+ PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error
+
+ // TODO: inotify_add_watch(); bind()
+}
+
+// PrependPathAtVFSRootError is returned by implementations of
+// FilesystemImpl.PrependPath() when they encounter the contextual VFS root.
+type PrependPathAtVFSRootError struct{}
+
+// Error implements error.Error.
+func (PrependPathAtVFSRootError) Error() string {
+ return "vfs.FilesystemImpl.PrependPath() reached VFS root"
+}
+
+// PrependPathAtNonMountRootError is returned by implementations of
+// FilesystemImpl.PrependPath() when they encounter an independent ancestor
+// Dentry that is not the Mount root.
+type PrependPathAtNonMountRootError struct{}
+
+// Error implements error.Error.
+func (PrependPathAtNonMountRootError) Error() string {
+ return "vfs.FilesystemImpl.PrependPath() reached root other than Mount root"
+}
+
+// PrependPathSyntheticError is returned by implementations of
+// FilesystemImpl.PrependPath() for which prepended names do not represent real
+// paths.
+type PrependPathSyntheticError struct{}
+
+// Error implements error.Error.
+func (PrependPathSyntheticError) Error() string {
+ return "vfs.FilesystemImpl.PrependPath() prepended synthetic name"
}
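The three sentinel error types give path-construction callers a way to distinguish "done" from "unreachable". A caller might interpret them like this (a sketch inside package vfs; the real handling lives in the VFS pathname helpers):

// interpretPrependPath is an illustrative helper, not sentry code.
func interpretPrependPath(err error) string {
	switch err.(type) {
	case nil:
		// Reached vd.Mount().Root(); a full walk would continue from
		// that Mount's own mount point.
		return "continue from the next mount up"
	case PrependPathAtVFSRootError:
		return "reached contextual VFS root; path complete"
	case PrependPathAtNonMountRootError:
		return "dentry detached from its Mount root; path is partial"
	case PrependPathSyntheticError:
		return "builder holds a synthetic name, not a real path"
	default:
		return "filesystem error: " + err.Error()
	}
}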
diff --git a/pkg/sentry/vfs/filesystem_impl_util.go b/pkg/sentry/vfs/filesystem_impl_util.go
new file mode 100644
index 000000000..7315a588e
--- /dev/null
+++ b/pkg/sentry/vfs/filesystem_impl_util.go
@@ -0,0 +1,69 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/fspath"
+)
+
+// GenericParseMountOptions parses a comma-separated list of options of the
+// form "key" or "key=value", where neither key nor value contain commas, and
+// returns it as a map. If str contains duplicate keys, then the last value
+// wins. For example:
+//
+// str = "key0=value0,key1,key2=value2,key0=value3" -> map{'key0':'value3','key1':'','key2':'value2'}
+//
+// GenericParseMountOptions is not appropriate if values may contain commas,
+// e.g. in the case of the mpol mount option for tmpfs(5).
+func GenericParseMountOptions(str string) map[string]string {
+ m := make(map[string]string)
+ for _, opt := range strings.Split(str, ",") {
+ if len(opt) > 0 {
+ res := strings.SplitN(opt, "=", 2)
+ if len(res) == 2 {
+ m[res[0]] = res[1]
+ } else {
+ m[opt] = ""
+ }
+ }
+ }
+ return m
+}
+
+// GenericPrependPath may be used by implementations of
+// FilesystemImpl.PrependPath() for which a single statically-determined lock
+// or set of locks is sufficient to ensure its preconditions (as opposed to
+// e.g. per-Dentry locks).
+//
+// Preconditions: Dentry.Name() and Dentry.Parent() must be held constant for
+// vd.Dentry() and all of its ancestors.
+func GenericPrependPath(vfsroot, vd VirtualDentry, b *fspath.Builder) error {
+ mnt, d := vd.mount, vd.dentry
+ for {
+ if mnt == vfsroot.mount && d == vfsroot.dentry {
+ return PrependPathAtVFSRootError{}
+ }
+ if d == mnt.root {
+ return nil
+ }
+ if d.parent == nil {
+ return PrependPathAtNonMountRootError{}
+ }
+ b.PrependComponent(d.name)
+ d = d.parent
+ }
+}
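
To make the walk above concrete, here is a minimal self-contained model of the same traversal: climb parent pointers, prepending names, until the mount root, the VFS root, or a parentless ancestor is reached. The dentry type and error values are simplified stand-ins for the real Dentry and sentinel error types:

    package main

    import (
        "errors"
        "fmt"
        "strings"
    )

    type dentry struct {
        name   string
        parent *dentry
    }

    var (
        errAtVFSRoot      = errors.New("reached VFS root")
        errAtNonMountRoot = errors.New("reached root other than mount root")
    )

    // prependPath models GenericPrependPath on a toy tree, accumulating
    // components front-to-back as it climbs toward the mount root.
    func prependPath(mountRoot, vfsRoot, d *dentry, components *[]string) error {
        for {
            if d == vfsRoot {
                return errAtVFSRoot
            }
            if d == mountRoot {
                return nil
            }
            if d.parent == nil {
                return errAtNonMountRoot
            }
            *components = append([]string{d.name}, *components...)
            d = d.parent
        }
    }

    func main() {
        root := &dentry{name: "/"}
        usr := &dentry{name: "usr", parent: root}
        bin := &dentry{name: "bin", parent: usr}
        var parts []string
        if err := prependPath(root, nil, bin, &parts); err == nil {
            fmt.Println("/" + strings.Join(parts, "/")) // /usr/bin
        }
    }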
diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go
index f401ad7f3..c335e206d 100644
--- a/pkg/sentry/vfs/filesystem_type.go
+++ b/pkg/sentry/vfs/filesystem_type.go
@@ -25,21 +25,21 @@ import (
//
// FilesystemType is analogous to Linux's struct file_system_type.
type FilesystemType interface {
- // NewFilesystem returns a Filesystem configured by the given options,
+ // GetFilesystem returns a Filesystem configured by the given options,
// along with its mount root. A reference is taken on the returned
// Filesystem and Dentry.
- NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts NewFilesystemOptions) (*Filesystem, *Dentry, error)
+ GetFilesystem(ctx context.Context, vfsObj *VirtualFilesystem, creds *auth.Credentials, source string, opts GetFilesystemOptions) (*Filesystem, *Dentry, error)
}
-// NewFilesystemOptions contains options to FilesystemType.NewFilesystem.
-type NewFilesystemOptions struct {
+// GetFilesystemOptions contains options to FilesystemType.GetFilesystem.
+type GetFilesystemOptions struct {
// Data is the string passed as the 5th argument to mount(2), which is
// usually a comma-separated list of filesystem-specific mount options.
Data string
// InternalData holds opaque FilesystemType-specific data. There is
// intentionally no way for applications to specify InternalData; if it is
- // not nil, the call to NewFilesystem originates from within the sentry.
+ // not nil, the call to GetFilesystem originates from within the sentry.
InternalData interface{}
}
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 11702f720..ec23ab0dd 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -18,6 +18,7 @@ import (
"math"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
@@ -38,16 +39,12 @@ import (
// Mount is analogous to Linux's struct mount. (gVisor does not distinguish
// between struct mount and struct vfsmount.)
type Mount struct {
- // The lower 63 bits of refs are a reference count. The MSB of refs is set
- // if the Mount has been eagerly unmounted, as by umount(2) without the
- // MNT_DETACH flag. refs is accessed using atomic memory operations.
- refs int64
-
- // The lower 63 bits of writers is the number of calls to
- // Mount.CheckBeginWrite() that have not yet been paired with a call to
- // Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect.
- // writers is accessed using atomic memory operations.
- writers int64
+ // vfs, fs, and root are immutable. References are held on fs and root.
+ //
+ // Invariant: root belongs to fs.
+ vfs *VirtualFilesystem
+ fs *Filesystem
+ root *Dentry
// key is protected by VirtualFilesystem.mountMu and
// VirtualFilesystem.mounts.seq, and may be nil. References are held on
@@ -57,13 +54,29 @@ type Mount struct {
// key.parent.fs.
key mountKey
- // fs, root, and ns are immutable. References are held on fs and root (but
- // not ns).
- //
- // Invariant: root belongs to fs.
- fs *Filesystem
- root *Dentry
- ns *MountNamespace
+ // ns is the namespace in which this Mount was mounted. ns is protected by
+ // VirtualFilesystem.mountMu.
+ ns *MountNamespace
+
+ // The lower 63 bits of refs are a reference count. The MSB of refs is set
+ // if the Mount has been eagerly umounted, as by umount(2) without the
+ // MNT_DETACH flag. refs is accessed using atomic memory operations.
+ refs int64
+
+ // children is the set of all Mounts for which Mount.key.parent is this
+ // Mount. children is protected by VirtualFilesystem.mountMu.
+ children map[*Mount]struct{}
+
+ // umounted is true if VFS.umountRecursiveLocked() has been called on this
+ // Mount. VirtualFilesystem does not hold a reference on Mounts for which
+ // umounted is true. umounted is protected by VirtualFilesystem.mountMu.
+ umounted bool
+
+ // The lower 63 bits of writers is the number of calls to
+ // Mount.CheckBeginWrite() that have not yet been paired with a call to
+ // Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect.
+ // writers is accessed using atomic memory operations.
+ writers int64
}
// A MountNamespace is a collection of Mounts.
@@ -73,13 +86,16 @@ type Mount struct {
//
// MountNamespace is analogous to Linux's struct mnt_namespace.
type MountNamespace struct {
- refs int64 // accessed using atomic memory operations
-
// root is the MountNamespace's root mount. root is immutable.
root *Mount
- // mountpoints contains all Dentries which are mount points in this
- // namespace. mountpoints is protected by VirtualFilesystem.mountMu.
+ // refs is the reference count. refs is accessed using atomic memory
+ // operations.
+ refs int64
+
+ // mountpoints maps all Dentries which are mount points in this namespace
+ // to the number of Mounts for which they are mount points. mountpoints is
+ // protected by VirtualFilesystem.mountMu.
//
// mountpoints is used to determine if a Dentry can be moved or removed
// (which requires that the Dentry is not a mount point in the calling
@@ -89,26 +105,27 @@ type MountNamespace struct {
// MountNamespace; this is required to ensure that
// VFS.PrepareDeleteDentry() and VFS.PrepareRemoveDentry() operate
// correctly on unreferenced MountNamespaces.
- mountpoints map[*Dentry]struct{}
+ mountpoints map[*Dentry]uint32
}
// NewMountNamespace returns a new mount namespace with a root filesystem
// configured by the given arguments. A reference is taken on the returned
// MountNamespace.
-func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *NewFilesystemOptions) (*MountNamespace, error) {
+func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) {
fsType := vfs.getFilesystemType(fsTypeName)
if fsType == nil {
return nil, syserror.ENODEV
}
- fs, root, err := fsType.NewFilesystem(ctx, creds, source, *opts)
+ fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, *opts)
if err != nil {
return nil, err
}
mntns := &MountNamespace{
refs: 1,
- mountpoints: make(map[*Dentry]struct{}),
+ mountpoints: make(map[*Dentry]uint32),
}
mntns.root = &Mount{
+ vfs: vfs,
fs: fs,
root: root,
ns: mntns,
@@ -117,13 +134,13 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth
return mntns, nil
}
-// NewMount creates and mounts a new Filesystem.
-func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *NewFilesystemOptions) error {
+// MountAt creates and mounts a Filesystem configured by the given arguments.
+func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error {
fsType := vfs.getFilesystemType(fsTypeName)
if fsType == nil {
return syserror.ENODEV
}
- fs, root, err := fsType.NewFilesystem(ctx, creds, source, *opts)
+ fs, root, err := fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions)
if err != nil {
return err
}
@@ -131,17 +148,19 @@ func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credenti
// lock ordering.
vd, err := vfs.GetDentryAt(ctx, creds, target, &GetDentryOptions{})
if err != nil {
- root.decRef(fs)
- fs.decRef()
+ root.DecRef()
+ fs.DecRef()
return err
}
vfs.mountMu.Lock()
+ vd.dentry.mu.Lock()
for {
if vd.dentry.IsDisowned() {
+ vd.dentry.mu.Unlock()
vfs.mountMu.Unlock()
vd.DecRef()
- root.decRef(fs)
- fs.decRef()
+ root.DecRef()
+ fs.DecRef()
return syserror.ENOENT
}
// vd might have been mounted over between vfs.GetDentryAt() and
@@ -153,36 +172,272 @@ func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credenti
if nextmnt == nil {
break
}
- nextmnt.incRef()
- nextmnt.root.incRef(nextmnt.fs)
+ // It's possible that nextmnt has been umounted but not disconnected,
+ // in which case vfs no longer holds a reference on it, and the last
+ // reference may be concurrently dropped even though we're holding
+ // vfs.mountMu.
+ if !nextmnt.tryIncMountedRef() {
+ break
+ }
+ // This can't fail since we're holding vfs.mountMu.
+ nextmnt.root.IncRef()
+ vd.dentry.mu.Unlock()
vd.DecRef()
vd = VirtualDentry{
mount: nextmnt,
dentry: nextmnt.root,
}
+ vd.dentry.mu.Lock()
}
// TODO: Linux requires that either both the mount point and the mount root
// are directories, or neither are, and returns ENOTDIR if this is not the
// case.
mntns := vd.mount.ns
mnt := &Mount{
+ vfs: vfs,
fs: fs,
root: root,
ns: mntns,
refs: 1,
}
- mnt.storeKey(vd.mount, vd.dentry)
+ vfs.mounts.seq.BeginWrite()
+ vfs.connectLocked(mnt, vd, mntns)
+ vfs.mounts.seq.EndWrite()
+ vd.dentry.mu.Unlock()
+ vfs.mountMu.Unlock()
+ return nil
+}
+
+// UmountAt removes the Mount at the given path.
+func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *UmountOptions) error {
+ if opts.Flags&^(linux.MNT_FORCE|linux.MNT_DETACH) != 0 {
+ return syserror.EINVAL
+ }
+
+ // MNT_FORCE is currently unimplemented except for the permission check.
+ if opts.Flags&linux.MNT_FORCE != 0 && !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) {
+ return syserror.EPERM
+ }
+
+ vd, err := vfs.GetDentryAt(ctx, creds, pop, &GetDentryOptions{})
+ if err != nil {
+ return err
+ }
+ defer vd.DecRef()
+ if vd.dentry != vd.mount.root {
+ return syserror.EINVAL
+ }
+ vfs.mountMu.Lock()
+ if mntns := MountNamespaceFromContext(ctx); mntns != nil && mntns != vd.mount.ns {
+ vfs.mountMu.Unlock()
+ return syserror.EINVAL
+ }
+
+ // TODO(jamieliu): Linux special-cases umount of the caller's root, which
+ // we don't implement yet (we'll just fail it since the caller holds a
+ // reference on it).
+
+ vfs.mounts.seq.BeginWrite()
+ if opts.Flags&linux.MNT_DETACH == 0 {
+ if len(vd.mount.children) != 0 {
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ return syserror.EBUSY
+ }
+ // We are holding a reference on vd.mount.
+ expectedRefs := int64(1)
+ if !vd.mount.umounted {
+ expectedRefs = 2
+ }
+ if atomic.LoadInt64(&vd.mount.refs)&^math.MinInt64 != expectedRefs { // mask out MSB
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ return syserror.EBUSY
+ }
+ }
+ vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(vd.mount, &umountRecursiveOptions{
+ eager: opts.Flags&linux.MNT_DETACH == 0,
+ disconnectHierarchy: true,
+ }, nil, nil)
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ for _, vd := range vdsToDecRef {
+ vd.DecRef()
+ }
+ for _, mnt := range mountsToDecRef {
+ mnt.DecRef()
+ }
+ return nil
+}
+
+type umountRecursiveOptions struct {
+ // If eager is true, ensure that future calls to Mount.tryIncMountedRef()
+ // on umounted mounts fail.
+ //
+ // eager is analogous to Linux's UMOUNT_SYNC.
+ eager bool
+
+ // If disconnectHierarchy is true, Mounts that are umounted hierarchically
+ // should be disconnected from their parents. (Mounts whose parents are not
+ // umounted, which in most cases means the Mount passed to the initial call
+ // to umountRecursiveLocked, are unconditionally disconnected for
+ // consistency with Linux.)
+ //
+ // disconnectHierarchy is analogous to Linux's !UMOUNT_CONNECTED.
+ disconnectHierarchy bool
+}
+
+// umountRecursiveLocked marks mnt and its descendants as umounted. It does not
+// release mount or dentry references; instead, it appends VirtualDentries and
+// Mounts on which references must be dropped to vdsToDecRef and mountsToDecRef
+// respectively, and returns updated slices. (This is necessary because
+// filesystem locks possibly taken by DentryImpl.DecRef() may precede
+// vfs.mountMu in the lock order, and Mount.DecRef() may lock vfs.mountMu.)
+//
+// umountRecursiveLocked is analogous to Linux's fs/namespace.c:umount_tree().
+//
+// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a
+// writer critical section.
+func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecursiveOptions, vdsToDecRef []VirtualDentry, mountsToDecRef []*Mount) ([]VirtualDentry, []*Mount) {
+ if !mnt.umounted {
+ mnt.umounted = true
+ mountsToDecRef = append(mountsToDecRef, mnt)
+ if parent := mnt.parent(); parent != nil && (opts.disconnectHierarchy || !parent.umounted) {
+ vdsToDecRef = append(vdsToDecRef, vfs.disconnectLocked(mnt))
+ }
+ }
+ if opts.eager {
+ for {
+ refs := atomic.LoadInt64(&mnt.refs)
+ if refs < 0 {
+ break
+ }
+ if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs|math.MinInt64) {
+ break
+ }
+ }
+ }
+ for child := range mnt.children {
+ vdsToDecRef, mountsToDecRef = vfs.umountRecursiveLocked(child, opts, vdsToDecRef, mountsToDecRef)
+ }
+ return vdsToDecRef, mountsToDecRef
+}
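
The shape of this function — mark and collect under the lock, release only after unlocking — is worth isolating, since it is exactly what lets DecRef implementations take locks that precede vfs.mountMu in the lock order. A reduced standalone sketch of the pattern; node, markDeadLocked, and mu are illustrative stand-ins, not types from this diff:

    package main

    import (
        "fmt"
        "sync"
    )

    type node struct {
        name     string
        children []*node
        dead     bool
    }

    var mu sync.Mutex // stands in for vfs.mountMu

    // markDeadLocked marks n and its descendants dead and appends them to
    // toRelease; their references are dropped only after mu is released.
    func markDeadLocked(n *node, toRelease []*node) []*node {
        if !n.dead {
            n.dead = true
            toRelease = append(toRelease, n)
        }
        for _, c := range n.children {
            toRelease = markDeadLocked(c, toRelease)
        }
        return toRelease
    }

    func main() {
        root := &node{name: "root", children: []*node{{name: "child"}}}
        mu.Lock()
        toRelease := markDeadLocked(root, nil)
        mu.Unlock()
        for _, n := range toRelease {
            fmt.Println("releasing", n.name) // safe: mu is no longer held
        }
    }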
+
+// connectLocked makes vd the mount parent/point for mnt. It consumes
+// references held by vd.
+//
+// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a
+// writer critical section. vd.dentry.mu must be locked. mnt.parent() == nil.
+func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns *MountNamespace) {
+ mnt.storeKey(vd)
+ if vd.mount.children == nil {
+ vd.mount.children = make(map[*Mount]struct{})
+ }
+ vd.mount.children[mnt] = struct{}{}
atomic.AddUint32(&vd.dentry.mounts, 1)
- mntns.mountpoints[vd.dentry] = struct{}{}
+ mntns.mountpoints[vd.dentry]++
+ vfs.mounts.insertSeqed(mnt)
vfsmpmounts, ok := vfs.mountpoints[vd.dentry]
if !ok {
vfsmpmounts = make(map[*Mount]struct{})
vfs.mountpoints[vd.dentry] = vfsmpmounts
}
vfsmpmounts[mnt] = struct{}{}
- vfs.mounts.Insert(mnt)
- vfs.mountMu.Unlock()
- return nil
+}
+
+// disconnectLocked makes mnt have no mount parent/point and returns mnt's old
+// mount parent/point with a reference held.
+//
+// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a
+// writer critical section. mnt.parent() != nil.
+func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry {
+ vd := mnt.loadKey()
+ mnt.storeKey(VirtualDentry{})
+ delete(vd.mount.children, mnt)
+ atomic.AddUint32(&vd.dentry.mounts, math.MaxUint32) // -1
+ mnt.ns.mountpoints[vd.dentry]--
+ if mnt.ns.mountpoints[vd.dentry] == 0 {
+ delete(mnt.ns.mountpoints, vd.dentry)
+ }
+ vfs.mounts.removeSeqed(mnt)
+ vfsmpmounts := vfs.mountpoints[vd.dentry]
+ delete(vfsmpmounts, mnt)
+ if len(vfsmpmounts) == 0 {
+ delete(vfs.mountpoints, vd.dentry)
+ }
+ return vd
+}
+
+// tryIncMountedRef increments mnt's reference count and returns true. If mnt's
+// reference count is already zero, or mnt has been eagerly umounted,
+// tryIncMountedRef does nothing and returns false.
+//
+// tryIncMountedRef does not require that a reference is held on mnt.
+func (mnt *Mount) tryIncMountedRef() bool {
+ for {
+ refs := atomic.LoadInt64(&mnt.refs)
+ if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
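
The refs encoding — a 63-bit count plus an MSB "eagerly umounted" bit — can be exercised on its own. A minimal sketch; countedRef and markDead are illustrative stand-ins for Mount's refcounting, not the real types:

    package main

    import (
        "fmt"
        "math"
        "sync/atomic"
    )

    type countedRef struct {
        refs int64 // low 63 bits: count; MSB: dead flag
    }

    // tryIncRef mirrors tryIncMountedRef: it fails once the MSB is set or
    // the count has dropped to zero.
    func (r *countedRef) tryIncRef() bool {
        for {
            refs := atomic.LoadInt64(&r.refs)
            if refs <= 0 {
                return false
            }
            if atomic.CompareAndSwapInt64(&r.refs, refs, refs+1) {
                return true
            }
        }
    }

    // markDead sets the MSB, as the eager case of umountRecursiveLocked does.
    func (r *countedRef) markDead() {
        for {
            refs := atomic.LoadInt64(&r.refs)
            if refs < 0 || atomic.CompareAndSwapInt64(&r.refs, refs, refs|math.MinInt64) {
                return
            }
        }
    }

    func main() {
        r := &countedRef{refs: 1}
        fmt.Println(r.tryIncRef())                            // true
        r.markDead()
        fmt.Println(r.tryIncRef())                            // false: MSB set
        fmt.Println(atomic.LoadInt64(&r.refs)&^math.MinInt64) // 2 live refs remain
    }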
+
+// IncRef increments mnt's reference count.
+func (mnt *Mount) IncRef() {
+ // In general, negative values for mnt.refs are valid because the MSB is
+ // the eager-unmount bit.
+ atomic.AddInt64(&mnt.refs, 1)
+}
+
+// DecRef decrements mnt's reference count.
+func (mnt *Mount) DecRef() {
+ refs := atomic.AddInt64(&mnt.refs, -1)
+ if refs&^math.MinInt64 == 0 { // mask out MSB
+ var vd VirtualDentry
+ if mnt.parent() != nil {
+ mnt.vfs.mountMu.Lock()
+ mnt.vfs.mounts.seq.BeginWrite()
+ vd = mnt.vfs.disconnectLocked(mnt)
+ mnt.vfs.mounts.seq.EndWrite()
+ mnt.vfs.mountMu.Unlock()
+ }
+ mnt.root.DecRef()
+ mnt.fs.DecRef()
+ if vd.Ok() {
+ vd.DecRef()
+ }
+ }
+}
+
+// IncRef increments mntns' reference count.
+func (mntns *MountNamespace) IncRef() {
+ if atomic.AddInt64(&mntns.refs, 1) <= 1 {
+ panic("MountNamespace.IncRef() called without holding a reference")
+ }
+}
+
+// DecRef decrements mntns' reference count.
+func (mntns *MountNamespace) DecRef(vfs *VirtualFilesystem) {
+ if refs := atomic.AddInt64(&mntns.refs, -1); refs == 0 {
+ vfs.mountMu.Lock()
+ vfs.mounts.seq.BeginWrite()
+ vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(mntns.root, &umountRecursiveOptions{
+ disconnectHierarchy: true,
+ }, nil, nil)
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ for _, vd := range vdsToDecRef {
+ vd.DecRef()
+ }
+ for _, mnt := range mountsToDecRef {
+ mnt.DecRef()
+ }
+ } else if refs < 0 {
+ panic("MountNamespace.DecRef() called without holding a reference")
+ }
}
// getMountAt returns the last Mount in the stack mounted at (mnt, d). It takes
@@ -223,7 +478,7 @@ retryFirst:
// Raced with umount.
continue
}
- mnt.decRef()
+ mnt.DecRef()
mnt = next
d = next.root
}
@@ -231,12 +486,12 @@ retryFirst:
}
// getMountpointAt returns the mount point for the stack of Mounts including
-// mnt. It takes a reference on the returned Mount and Dentry. If no such mount
+// mnt. It takes a reference on the returned VirtualDentry. If no such mount
-// point exists (i.e. mnt is a root mount), getMountpointAt returns (nil, nil).
+// point exists (i.e. mnt is a root mount), getMountpointAt returns an empty
+// VirtualDentry.
//
// Preconditions: References are held on mnt and root. vfsroot is not (mnt,
// mnt.root).
-func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) (*Mount, *Dentry) {
+func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) VirtualDentry {
// The first mount is special-cased:
//
// - The caller must have already checked mnt against vfsroot.
@@ -246,21 +501,26 @@ func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry)
// - We don't drop the caller's reference on mnt.
retryFirst:
epoch := vfs.mounts.seq.BeginRead()
- parent, point := mnt.loadKey()
+ parent, point := mnt.parent(), mnt.point()
if !vfs.mounts.seq.ReadOk(epoch) {
goto retryFirst
}
if parent == nil {
- return nil, nil
+ return VirtualDentry{}
}
if !parent.tryIncMountedRef() {
// Raced with umount.
goto retryFirst
}
- if !point.tryIncRef(parent.fs) {
+ if !point.TryIncRef() {
// Since Mount holds a reference on Mount.key.point, this can only
// happen due to a racing change to Mount.key.
- parent.decRef()
+ parent.DecRef()
+ goto retryFirst
+ }
+ if !vfs.mounts.seq.ReadOk(epoch) {
+ point.DecRef()
+ parent.DecRef()
goto retryFirst
}
mnt = parent
@@ -274,7 +534,7 @@ retryFirst:
}
retryNotFirst:
epoch := vfs.mounts.seq.BeginRead()
- parent, point := mnt.loadKey()
+ parent, point := mnt.parent(), mnt.point()
if !vfs.mounts.seq.ReadOk(epoch) {
goto retryNotFirst
}
@@ -285,59 +545,23 @@ retryFirst:
// Raced with umount.
goto retryNotFirst
}
- if !point.tryIncRef(parent.fs) {
+ if !point.TryIncRef() {
// Since Mount holds a reference on Mount.key.point, this can
// only happen due to a racing change to Mount.key.
- parent.decRef()
+ parent.DecRef()
goto retryNotFirst
}
if !vfs.mounts.seq.ReadOk(epoch) {
- point.decRef(parent.fs)
- parent.decRef()
+ point.DecRef()
+ parent.DecRef()
goto retryNotFirst
}
- d.decRef(mnt.fs)
- mnt.decRef()
+ d.DecRef()
+ mnt.DecRef()
mnt = parent
d = point
}
- return mnt, d
-}
-
-// tryIncMountedRef increments mnt's reference count and returns true. If mnt's
-// reference count is already zero, or has been eagerly unmounted,
-// tryIncMountedRef does nothing and returns false.
-//
-// tryIncMountedRef does not require that a reference is held on mnt.
-func (mnt *Mount) tryIncMountedRef() bool {
- for {
- refs := atomic.LoadInt64(&mnt.refs)
- if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted
- return false
- }
- if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) {
- return true
- }
- }
-}
-
-func (mnt *Mount) incRef() {
- // In general, negative values for mnt.refs are valid because the MSB is
- // the eager-unmount bit.
- atomic.AddInt64(&mnt.refs, 1)
-}
-
-func (mnt *Mount) decRef() {
- refs := atomic.AddInt64(&mnt.refs, -1)
- if refs&^math.MinInt64 == 0 { // mask out MSB
- parent, point := mnt.loadKey()
- if point != nil {
- point.decRef(parent.fs)
- parent.decRef()
- }
- mnt.root.decRef(mnt.fs)
- mnt.fs.decRef()
- }
+ return VirtualDentry{mnt, d}
}
// CheckBeginWrite increments the counter of in-progress write operations on
@@ -360,7 +584,7 @@ func (mnt *Mount) EndWrite() {
atomic.AddInt64(&mnt.writers, -1)
}
-// Preconditions: VirtualFilesystem.mountMu must be locked for writing.
+// Preconditions: VirtualFilesystem.mountMu must be locked.
func (mnt *Mount) setReadOnlyLocked(ro bool) error {
if oldRO := atomic.LoadInt64(&mnt.writers) < 0; oldRO == ro {
return nil
@@ -383,22 +607,6 @@ func (mnt *Mount) Filesystem() *Filesystem {
return mnt.fs
}
-// IncRef increments mntns' reference count.
-func (mntns *MountNamespace) IncRef() {
- if atomic.AddInt64(&mntns.refs, 1) <= 1 {
- panic("MountNamespace.IncRef() called without holding a reference")
- }
-}
-
-// DecRef decrements mntns' reference count.
-func (mntns *MountNamespace) DecRef() {
- if refs := atomic.AddInt64(&mntns.refs, 0); refs == 0 {
- // TODO: unmount mntns.root
- } else if refs < 0 {
- panic("MountNamespace.DecRef() called without holding a reference")
- }
-}
-
// Root returns mntns' root. A reference is taken on the returned
// VirtualDentry.
func (mntns *MountNamespace) Root() VirtualDentry {
diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go
index f394d7483..adff0b94b 100644
--- a/pkg/sentry/vfs/mount_test.go
+++ b/pkg/sentry/vfs/mount_test.go
@@ -37,7 +37,7 @@ func TestMountTableInsertLookup(t *testing.T) {
mt.Init()
mount := &Mount{}
- mount.storeKey(&Mount{}, &Dentry{})
+ mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}})
mt.Insert(mount)
if m := mt.Lookup(mount.parent(), mount.point()); m != mount {
@@ -78,18 +78,10 @@ const enableComparativeBenchmarks = false
func newBenchMount() *Mount {
mount := &Mount{}
- mount.storeKey(&Mount{}, &Dentry{})
+ mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}})
return mount
}
-func vdkey(mnt *Mount) VirtualDentry {
- parent, point := mnt.loadKey()
- return VirtualDentry{
- mount: parent,
- dentry: point,
- }
-}
-
func BenchmarkMountTableParallelLookup(b *testing.B) {
for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 {
for _, numMounts := range benchNumMounts {
@@ -101,7 +93,7 @@ func BenchmarkMountTableParallelLookup(b *testing.B) {
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
mt.Insert(mount)
- keys = append(keys, vdkey(mount))
+ keys = append(keys, mount.loadKey())
}
var ready sync.WaitGroup
@@ -153,7 +145,7 @@ func BenchmarkMountMapParallelLookup(b *testing.B) {
keys := make([]VirtualDentry, 0, numMounts)
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
- key := vdkey(mount)
+ key := mount.loadKey()
ms[key] = mount
keys = append(keys, key)
}
@@ -208,7 +200,7 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) {
keys := make([]VirtualDentry, 0, numMounts)
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
- key := vdkey(mount)
+ key := mount.loadKey()
ms.Store(key, mount)
keys = append(keys, key)
}
@@ -290,7 +282,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) {
ms := make(map[VirtualDentry]*Mount)
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
- ms[vdkey(mount)] = mount
+ ms[mount.loadKey()] = mount
}
negkeys := make([]VirtualDentry, 0, numMounts)
for i := 0; i < numMounts; i++ {
@@ -325,7 +317,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) {
var ms sync.Map
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
- ms.Store(vdkey(mount), mount)
+ ms.Store(mount.loadKey(), mount)
}
negkeys := make([]VirtualDentry, 0, numMounts)
for i := 0; i < numMounts; i++ {
@@ -379,7 +371,7 @@ func BenchmarkMountMapInsert(b *testing.B) {
b.ResetTimer()
for i := range mounts {
mount := mounts[i]
- ms[vdkey(mount)] = mount
+ ms[mount.loadKey()] = mount
}
}
@@ -399,7 +391,7 @@ func BenchmarkMountSyncMapInsert(b *testing.B) {
b.ResetTimer()
for i := range mounts {
mount := mounts[i]
- ms.Store(vdkey(mount), mount)
+ ms.Store(mount.loadKey(), mount)
}
}
@@ -432,13 +424,13 @@ func BenchmarkMountMapRemove(b *testing.B) {
ms := make(map[VirtualDentry]*Mount)
for i := range mounts {
mount := mounts[i]
- ms[vdkey(mount)] = mount
+ ms[mount.loadKey()] = mount
}
b.ResetTimer()
for i := range mounts {
mount := mounts[i]
- delete(ms, vdkey(mount))
+ delete(ms, mount.loadKey())
}
}
@@ -454,12 +446,12 @@ func BenchmarkMountSyncMapRemove(b *testing.B) {
var ms sync.Map
for i := range mounts {
mount := mounts[i]
- ms.Store(vdkey(mount), mount)
+ ms.Store(mount.loadKey(), mount)
}
b.ResetTimer()
for i := range mounts {
mount := mounts[i]
- ms.Delete(vdkey(mount))
+ ms.Delete(mount.loadKey())
}
}
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index b0511aa40..ab13fa461 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.12
-// +build !go1.14
+// +build !go1.15
// Check go:linkname function signatures when updating Go version.
@@ -26,7 +26,7 @@ import (
"sync/atomic"
"unsafe"
- "gvisor.dev/gvisor/third_party/gvsync"
+ "gvisor.dev/gvisor/pkg/syncutil"
)
// mountKey represents the location at which a Mount is mounted. It is
@@ -38,16 +38,6 @@ type mountKey struct {
point unsafe.Pointer // *Dentry
}
-// Invariant: mnt.key's fields are nil. parent and point are non-nil.
-func (mnt *Mount) storeKey(parent *Mount, point *Dentry) {
- atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(parent))
- atomic.StorePointer(&mnt.key.point, unsafe.Pointer(point))
-}
-
-func (mnt *Mount) loadKey() (*Mount, *Dentry) {
- return (*Mount)(atomic.LoadPointer(&mnt.key.parent)), (*Dentry)(atomic.LoadPointer(&mnt.key.point))
-}
-
func (mnt *Mount) parent() *Mount {
return (*Mount)(atomic.LoadPointer(&mnt.key.parent))
}
@@ -56,6 +46,19 @@ func (mnt *Mount) point() *Dentry {
return (*Dentry)(atomic.LoadPointer(&mnt.key.point))
}
+func (mnt *Mount) loadKey() VirtualDentry {
+ return VirtualDentry{
+ mount: mnt.parent(),
+ dentry: mnt.point(),
+ }
+}
+
+// Invariant: mnt.key.parent == nil. vd.Ok().
+func (mnt *Mount) storeKey(vd VirtualDentry) {
+ atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount))
+ atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry))
+}
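
A tiny standalone sketch of the same idiom — atomic loads and stores of typed pointers through unsafe.Pointer, so lock-free readers always observe a consistent value. The dentry and key types here are illustrative stand-ins:

    package main

    import (
        "fmt"
        "sync/atomic"
        "unsafe"
    )

    type dentry struct{ name string }

    type key struct {
        point unsafe.Pointer // *dentry, accessed atomically
    }

    func (k *key) store(d *dentry) { atomic.StorePointer(&k.point, unsafe.Pointer(d)) }
    func (k *key) load() *dentry   { return (*dentry)(atomic.LoadPointer(&k.point)) }

    func main() {
        var k key
        k.store(&dentry{name: "mnt"})
        fmt.Println(k.load().name) // mnt
    }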
+
// mountTable maps (mount parent, mount point) pairs to mounts. It supports
// efficient concurrent lookup, even in the presence of concurrent mutators
// (provided mutation is sufficiently uncommon).
@@ -72,7 +75,7 @@ type mountTable struct {
// intrinsics and inline assembly, limiting the performance of this
// approach.)
- seq gvsync.SeqCount
+ seq syncutil.SeqCount
seed uint32 // for hashing keys
// size holds both length (number of elements) and capacity (number of
@@ -201,9 +204,19 @@ loop:
// Insert inserts the given mount into mt.
//
-// Preconditions: There are no concurrent mutators of mt. mt must not already
-// contain a Mount with the same mount point and parent.
+// Preconditions: mt must not already contain a Mount with the same mount point
+// and parent.
func (mt *mountTable) Insert(mount *Mount) {
+ mt.seq.BeginWrite()
+ mt.insertSeqed(mount)
+ mt.seq.EndWrite()
+}
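
The seq field follows the usual seqcount protocol: writers bracket mutations with BeginWrite/EndWrite, and readers retry if the count was odd or changed underneath them. A toy implementation of that protocol; seqCount is a simplified stand-in for syncutil.SeqCount, not its real implementation:

    package main

    import (
        "fmt"
        "sync/atomic"
    )

    type seqCount struct{ n uint32 }

    func (s *seqCount) beginWrite() { atomic.AddUint32(&s.n, 1) } // count becomes odd
    func (s *seqCount) endWrite()   { atomic.AddUint32(&s.n, 1) } // count becomes even

    // beginRead spins until no write is in progress, returning the epoch.
    func (s *seqCount) beginRead() uint32 {
        for {
            if epoch := atomic.LoadUint32(&s.n); epoch%2 == 0 {
                return epoch
            }
        }
    }

    // readOk reports whether no write occurred since beginRead.
    func (s *seqCount) readOk(epoch uint32) bool {
        return atomic.LoadUint32(&s.n) == epoch
    }

    func main() {
        var s seqCount
        epoch := s.beginRead()
        s.beginWrite()
        s.endWrite()
        fmt.Println(s.readOk(epoch)) // false: reader must retry
    }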
+
+// insertSeqed inserts the given mount into mt.
+//
+// Preconditions: mt.seq must be in a writer critical section. mt must not
+// already contain a Mount with the same mount point and parent.
+func (mt *mountTable) insertSeqed(mount *Mount) {
hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
// We're under the maximum load factor if:
@@ -215,10 +228,8 @@ func (mt *mountTable) Insert(mount *Mount) {
tcap := uintptr(1) << order
if ((tlen + 1) * mtMaxLoadDen) <= (uint64(mtMaxLoadNum) << order) {
// Atomically insert the new element into the table.
- mt.seq.BeginWrite()
atomic.AddUint64(&mt.size, mtSizeLenOne)
mtInsertLocked(mt.slots, tcap, unsafe.Pointer(mount), hash)
- mt.seq.EndWrite()
return
}
@@ -241,8 +252,6 @@ func (mt *mountTable) Insert(mount *Mount) {
for {
oldSlot := (*mountSlot)(oldCur)
if oldSlot.value != nil {
- // Don't need to lock mt.seq yet since newSlots isn't visible
- // to readers.
mtInsertLocked(newSlots, newCap, oldSlot.value, oldSlot.hash)
}
if oldCur == oldLast {
@@ -252,11 +261,9 @@ func (mt *mountTable) Insert(mount *Mount) {
}
// Insert the new element into the new table.
mtInsertLocked(newSlots, newCap, unsafe.Pointer(mount), hash)
- // Atomically switch to the new table.
- mt.seq.BeginWrite()
+ // Switch to the new table.
atomic.AddUint64(&mt.size, mtSizeLenOne|mtSizeOrderOne)
atomic.StorePointer(&mt.slots, newSlots)
- mt.seq.EndWrite()
}
// Preconditions: There are no concurrent mutators of the table (slots, cap).
@@ -294,9 +301,18 @@ func mtInsertLocked(slots unsafe.Pointer, cap uintptr, value unsafe.Pointer, has
// Remove removes the given mount from mt.
//
-// Preconditions: There are no concurrent mutators of mt. mt must contain
-// mount.
+// Preconditions: mt must contain mount.
func (mt *mountTable) Remove(mount *Mount) {
+ mt.seq.BeginWrite()
+ mt.removeSeqed(mount)
+ mt.seq.EndWrite()
+}
+
+// removeSeqed removes the given mount from mt.
+//
+// Preconditions: mt.seq must be in a writer critical section. mt must contain
+// mount.
+func (mt *mountTable) removeSeqed(mount *Mount) {
hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
mask := tcap - 1
@@ -311,7 +327,6 @@ func (mt *mountTable) Remove(mount *Mount) {
// backward until we either find an empty slot, or an element that
// is already in its first-probed slot. (This is backward shift
// deletion.)
- mt.seq.BeginWrite()
for {
nextOff := (off + mountSlotBytes) & offmask
nextSlot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + nextOff))
@@ -330,7 +345,6 @@ func (mt *mountTable) Remove(mount *Mount) {
}
atomic.StorePointer(&slot.value, nil)
atomic.AddUint64(&mt.size, mtSizeLenNegOne)
- mt.seq.EndWrite()
return
}
if checkInvariants && slotValue == nil {
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 3aa73d911..97ee4a446 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -46,6 +46,12 @@ type MknodOptions struct {
DevMinor uint32
}
+// MountOptions contains options to VirtualFilesystem.MountAt().
+type MountOptions struct {
+ // GetFilesystemOptions contains options to FilesystemType.GetFilesystem().
+ GetFilesystemOptions GetFilesystemOptions
+}
+
// OpenOptions contains options to VirtualFilesystem.OpenAt() and
// FilesystemImpl.OpenAt().
type OpenOptions struct {
@@ -95,6 +101,20 @@ type SetStatOptions struct {
Stat linux.Statx
}
+// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(),
+// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and
+// FileDescriptionImpl.Setxattr().
+type SetxattrOptions struct {
+ // Name is the name of the extended attribute being mutated.
+ Name string
+
+ // Value is the extended attribute's new value.
+ Value string
+
+ // Flags contains flags as specified for setxattr/lsetxattr/fsetxattr(2).
+ Flags uint32
+}
+
// StatOptions contains options to VirtualFilesystem.StatAt(),
// FilesystemImpl.StatAt(), FileDescription.Stat(), and
// FileDescriptionImpl.Stat().
@@ -114,6 +134,12 @@ type StatOptions struct {
Sync uint32
}
+// UmountOptions contains options to VirtualFilesystem.UmountAt().
+type UmountOptions struct {
+ // Flags contains flags as specified for umount2(2).
+ Flags uint32
+}
+
// WriteOptions contains options to FileDescription.PWrite(),
// FileDescriptionImpl.PWrite(), FileDescription.Write(), and
// FileDescriptionImpl.Write().
diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go
new file mode 100644
index 000000000..8e155654f
--- /dev/null
+++ b/pkg/sentry/vfs/pathname.go
@@ -0,0 +1,153 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+var fspathBuilderPool = sync.Pool{
+ New: func() interface{} {
+ return &fspath.Builder{}
+ },
+}
+
+func getFSPathBuilder() *fspath.Builder {
+ return fspathBuilderPool.Get().(*fspath.Builder)
+}
+
+func putFSPathBuilder(b *fspath.Builder) {
+ // No methods can be called on b after b.String(), so reset it to its zero
+ // value (as returned by fspathBuilderPool.New) instead.
+ *b = fspath.Builder{}
+ fspathBuilderPool.Put(b)
+}
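
The reset-before-Put discipline above generalizes to any pooled builder. A comparable sketch using the standard library's strings.Builder as a stand-in; note the diff zeroes the builder because fspath.Builder cannot be reused after String(), whereas strings.Builder offers Reset:

    package main

    import (
        "fmt"
        "strings"
        "sync"
    )

    var builderPool = sync.Pool{
        New: func() interface{} { return &strings.Builder{} },
    }

    func getBuilder() *strings.Builder { return builderPool.Get().(*strings.Builder) }

    func putBuilder(b *strings.Builder) {
        b.Reset() // return the object to a reusable state before pooling it
        builderPool.Put(b)
    }

    func main() {
        b := getBuilder()
        b.WriteString("/usr/bin")
        fmt.Println(b.String())
        putBuilder(b)
    }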
+
+// PathnameWithDeleted returns an absolute pathname to vd, consistent with
+// Linux's d_path(). In particular, if vd.Dentry() has been disowned,
+// PathnameWithDeleted appends " (deleted)" to the returned pathname.
+func (vfs *VirtualFilesystem) PathnameWithDeleted(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) {
+ b := getFSPathBuilder()
+ defer putFSPathBuilder(b)
+ haveRef := false
+ defer func() {
+ if haveRef {
+ vd.DecRef()
+ }
+ }()
+
+ origD := vd.dentry
+loop:
+ for {
+ err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b)
+ switch err.(type) {
+ case nil:
+ if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry {
+ // GenericPrependPath() will have returned
+ // PrependPathAtVFSRootError in this case since it checks
+ // against vfsroot before mnt.root, but other implementations
+ // of FilesystemImpl.PrependPath() may return nil instead.
+ break loop
+ }
+ nextVD := vfs.getMountpointAt(vd.mount, vfsroot)
+ if !nextVD.Ok() {
+ break loop
+ }
+ if haveRef {
+ vd.DecRef()
+ }
+ vd = nextVD
+ haveRef = true
+ // continue loop
+ case PrependPathSyntheticError:
+ // Skip prepending "/" and appending " (deleted)".
+ return b.String(), nil
+ case PrependPathAtVFSRootError, PrependPathAtNonMountRootError:
+ break loop
+ default:
+ return "", err
+ }
+ }
+ b.PrependByte('/')
+ if origD.IsDisowned() {
+ b.AppendString(" (deleted)")
+ }
+ return b.String(), nil
+}
+
+// PathnameForGetcwd returns an absolute pathname to vd, consistent with
+// Linux's sys_getcwd().
+func (vfs *VirtualFilesystem) PathnameForGetcwd(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) {
+ if vd.dentry.IsDisowned() {
+ return "", syserror.ENOENT
+ }
+
+ b := getFSPathBuilder()
+ defer putFSPathBuilder(b)
+ haveRef := false
+ defer func() {
+ if haveRef {
+ vd.DecRef()
+ }
+ }()
+ unreachable := false
+loop:
+ for {
+ err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b)
+ switch err.(type) {
+ case nil:
+ if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry {
+ break loop
+ }
+ nextVD := vfs.getMountpointAt(vd.mount, vfsroot)
+ if !nextVD.Ok() {
+ unreachable = true
+ break loop
+ }
+ if haveRef {
+ vd.DecRef()
+ }
+ vd = nextVD
+ haveRef = true
+ case PrependPathAtVFSRootError:
+ break loop
+ case PrependPathAtNonMountRootError, PrependPathSyntheticError:
+ unreachable = true
+ break loop
+ default:
+ return "", err
+ }
+ }
+ b.PrependByte('/')
+ if unreachable {
+ b.PrependString("(unreachable)")
+ }
+ return b.String(), nil
+}
+
+// As of this writing, we do not have equivalents to:
+//
+// - d_absolute_path(), which returns EINVAL if (effectively) any call to
+// FilesystemImpl.PrependPath() would return PrependPathAtNonMountRootError.
+//
+// - dentry_path(), which does not walk up mounts (and only returns the path
+// relative to Filesystem root), but also appends "//deleted" for disowned
+// Dentries.
+//
+// These should be added as necessary.
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index f8e74355c..f1edb0680 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -119,3 +119,65 @@ func MayWriteFileWithOpenFlags(flags uint32) bool {
return false
}
}
+
+// CheckSetStat checks that creds has permission to change the metadata of a
+// file with the given permissions, UID, and GID as specified by stat, subject
+// to the rules of Linux's fs/attr.c:setattr_prepare().
+func CheckSetStat(creds *auth.Credentials, stat *linux.Statx, mode uint16, kuid auth.KUID, kgid auth.KGID) error {
+ if stat.Mask&linux.STATX_MODE != 0 {
+ if !CanActAsOwner(creds, kuid) {
+ return syserror.EPERM
+ }
+ // TODO(b/30815691): "If the calling process is not privileged (Linux:
+ // does not have the CAP_FSETID capability), and the group of the file
+ // does not match the effective group ID of the process or one of its
+ // supplementary group IDs, the S_ISGID bit will be turned off, but
+ // this will not cause an error to be returned." - chmod(2)
+ }
+ if stat.Mask&linux.STATX_UID != 0 {
+ if !((creds.EffectiveKUID == kuid && auth.KUID(stat.UID) == kuid) ||
+ HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) {
+ return syserror.EPERM
+ }
+ }
+ if stat.Mask&linux.STATX_GID != 0 {
+ if !((creds.EffectiveKUID == kuid && creds.InGroup(auth.KGID(stat.GID))) ||
+ HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) {
+ return syserror.EPERM
+ }
+ }
+ if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 {
+ if !CanActAsOwner(creds, kuid) {
+ if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) ||
+ (stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW) ||
+ (stat.Mask&linux.STATX_CTIME != 0 && stat.Ctime.Nsec != linux.UTIME_NOW) {
+ return syserror.EPERM
+ }
+ // isDir is irrelevant in the following call to
+ // GenericCheckPermissions since ats == MayWrite means that
+ // CAP_DAC_READ_SEARCH does not apply, and CAP_DAC_OVERRIDE
+ // applies, regardless of isDir.
+ if err := GenericCheckPermissions(creds, MayWrite, false /* isDir */, mode, kuid, kgid); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
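
A reduced model of the mask-driven structure above, using simplified stand-in constants and credential fields rather than the real linux and auth packages: each STATX_* bit selects a rule, and ownership changes additionally require CAP_CHOWN unless they are no-ops by the owner:

    package main

    import "fmt"

    const (
        statxMode = 1 << iota
        statxUID
    )

    type creds struct {
        euid     uint32
        capChown bool
    }

    // checkSetStat applies two of the setattr_prepare()-style rules: chmod
    // requires ownership, and chown requires CAP_CHOWN unless it is a no-op
    // performed by the owner.
    func checkSetStat(c creds, mask, newUID, fileUID uint32) error {
        if mask&statxMode != 0 && c.euid != fileUID {
            return fmt.Errorf("EPERM: only the owner may change the mode")
        }
        if mask&statxUID != 0 {
            if !(c.euid == fileUID && newUID == fileUID) && !c.capChown {
                return fmt.Errorf("EPERM: CAP_CHOWN required to change owner")
            }
        }
        return nil
    }

    func main() {
        err := checkSetStat(creds{euid: 1000}, statxUID, 0, 1000)
        fmt.Println(err) // EPERM: CAP_CHOWN required to change owner
    }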
+
+// CanActAsOwner returns true if creds can act as the owner of a file with the
+// given owning UID, consistent with Linux's
+// fs/inode.c:inode_owner_or_capable().
+func CanActAsOwner(creds *auth.Credentials, kuid auth.KUID) bool {
+ if creds.EffectiveKUID == kuid {
+ return true
+ }
+ return creds.HasCapability(linux.CAP_FOWNER) && creds.UserNamespace.MapFromKUID(kuid).Ok()
+}
+
+// HasCapabilityOnFile returns true if creds has the given capability with
+// respect to a file with the given owning UID and GID, consistent with Linux's
+// kernel/capability.c:capable_wrt_inode_uidgid().
+func HasCapabilityOnFile(creds *auth.Credentials, cp linux.Capability, kuid auth.KUID, kgid auth.KGID) bool {
+ return creds.HasCapability(cp) && creds.UserNamespace.MapFromKUID(kuid).Ok() && creds.UserNamespace.MapFromKGID(kgid).Ok()
+}
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index 8d05c8583..d580fd39e 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -85,11 +85,11 @@ func init() {
// so error "constants" are really mutable vars, necessitating somewhat
// expensive interface object comparisons.
-type resolveMountRootError struct{}
+type resolveMountRootOrJumpError struct{}
// Error implements error.Error.
-func (resolveMountRootError) Error() string {
- return "resolving mount root"
+func (resolveMountRootOrJumpError) Error() string {
+ return "resolving mount root or jump"
}
type resolveMountPointError struct{}
@@ -149,20 +149,20 @@ func (vfs *VirtualFilesystem) putResolvingPath(rp *ResolvingPath) {
func (rp *ResolvingPath) decRefStartAndMount() {
if rp.flags&rpflagsHaveStartRef != 0 {
- rp.start.decRef(rp.mount.fs)
+ rp.start.DecRef()
}
if rp.flags&rpflagsHaveMountRef != 0 {
- rp.mount.decRef()
+ rp.mount.DecRef()
}
}
func (rp *ResolvingPath) releaseErrorState() {
if rp.nextStart != nil {
- rp.nextStart.decRef(rp.nextMount.fs)
+ rp.nextStart.DecRef()
rp.nextStart = nil
}
if rp.nextMount != nil {
- rp.nextMount.decRef()
+ rp.nextMount.DecRef()
rp.nextMount = nil
}
}
@@ -269,12 +269,12 @@ func (rp *ResolvingPath) ResolveParent(d *Dentry) (*Dentry, error) {
parent = d
} else if d == rp.mount.root {
// At mount root ...
- mnt, mntpt := rp.vfs.getMountpointAt(rp.mount, rp.root)
- if mnt != nil {
+ vd := rp.vfs.getMountpointAt(rp.mount, rp.root)
+ if vd.Ok() {
// ... of non-root mount.
- rp.nextMount = mnt
- rp.nextStart = mntpt
- return nil, resolveMountRootError{}
+ rp.nextMount = vd.mount
+ rp.nextStart = vd.dentry
+ return nil, resolveMountRootOrJumpError{}
}
// ... of root mount.
parent = d
@@ -385,11 +385,32 @@ func (rp *ResolvingPath) relpathPrepend(path fspath.Path) {
}
}
+// HandleJump is called when the current path component is a "magic" link to
+// the given VirtualDentry, like /proc/[pid]/fd/[fd]. If the calling Filesystem
+// method should continue path traversal, HandleJump updates the path
+// component stream to reflect the magic link target and returns nil. Otherwise
+// it returns a non-nil error.
+//
+// Preconditions: !rp.Done().
+func (rp *ResolvingPath) HandleJump(target VirtualDentry) error {
+ if rp.symlinks >= linux.MaxSymlinkTraversals {
+ return syserror.ELOOP
+ }
+ rp.symlinks++
+ // Consume the path component that represented the magic link.
+ rp.Advance()
+ // Unconditionally return a resolveMountRootOrJumpError, even if the Mount
+ // isn't changing, to force restarting at the new Dentry.
+ target.IncRef()
+ rp.nextMount = target.mount
+ rp.nextStart = target.dentry
+ return resolveMountRootOrJumpError{}
+}
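
The traversal budget enforced here is the same one Linux applies to ordinary symlinks. A self-contained illustration of why the counter matters; the link map is a toy stand-in for real path resolution:

    package main

    import (
        "errors"
        "fmt"
    )

    const maxSymlinkTraversals = 40 // matches linux.MaxSymlinkTraversals

    // resolveLinks follows links until a non-link name is reached, failing
    // with ELOOP once the traversal budget is exhausted.
    func resolveLinks(links map[string]string, name string) (string, error) {
        for i := 0; ; i++ {
            target, ok := links[name]
            if !ok {
                return name, nil
            }
            if i >= maxSymlinkTraversals {
                return "", errors.New("ELOOP")
            }
            name = target
        }
    }

    func main() {
        links := map[string]string{"a": "b", "b": "a"} // a cycle
        _, err := resolveLinks(links, "a")
        fmt.Println(err) // ELOOP
    }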
+
func (rp *ResolvingPath) handleError(err error) bool {
switch err.(type) {
- case resolveMountRootError:
- // Switch to the new Mount. We hold references on the Mount and Dentry
- // (from VFS.getMountpointAt()).
+ case resolveMountRootOrJumpError:
+ // Switch to the new Mount. We hold references on the Mount and Dentry.
rp.decRefStartAndMount()
rp.mount = rp.nextMount
rp.start = rp.nextStart
@@ -407,9 +428,8 @@ func (rp *ResolvingPath) handleError(err error) bool {
return true
case resolveMountPointError:
- // Switch to the new Mount. We hold a reference on the Mount (from
- // VFS.getMountAt()), but borrow the reference on the mount root from
- // the Mount.
+ // Switch to the new Mount. We hold a reference on the Mount, but
+ // borrow the reference on the mount root from the Mount.
rp.decRefStartAndMount()
rp.mount = rp.nextMount
rp.start = rp.nextMount.root
diff --git a/pkg/sentry/vfs/syscalls.go b/pkg/sentry/vfs/syscalls.go
deleted file mode 100644
index 23f2b9e08..000000000
--- a/pkg/sentry/vfs/syscalls.go
+++ /dev/null
@@ -1,217 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package vfs
-
-import (
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-// PathOperation specifies the path operated on by a VFS method.
-//
-// PathOperation is passed to VFS methods by pointer to reduce memory copying:
-// it's somewhat large and should never escape. (Options structs are passed by
-// pointer to VFS and FileDescription methods for the same reason.)
-type PathOperation struct {
- // Root is the VFS root. References on Root are borrowed from the provider
- // of the PathOperation.
- //
- // Invariants: Root.Ok().
- Root VirtualDentry
-
- // Start is the starting point for the path traversal. References on Start
- // are borrowed from the provider of the PathOperation (i.e. the caller of
- // the VFS method to which the PathOperation was passed).
- //
- // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root.
- Start VirtualDentry
-
- // Path is the pathname traversed by this operation.
- Pathname string
-
- // If FollowFinalSymlink is true, and the Dentry traversed by the final
- // path component represents a symbolic link, the symbolic link should be
- // followed.
- FollowFinalSymlink bool
-}
-
-// GetDentryAt returns a VirtualDentry representing the given path, at which a
-// file must exist. A reference is taken on the returned VirtualDentry.
-func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return VirtualDentry{}, err
- }
- for {
- d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts)
- if err == nil {
- vd := VirtualDentry{
- mount: rp.mount,
- dentry: d,
- }
- rp.mount.incRef()
- vfs.putResolvingPath(rp)
- return vd, nil
- }
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return VirtualDentry{}, err
- }
- }
-}
-
-// MkdirAt creates a directory at the given path.
-func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error {
- // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is
- // also honored." - mkdir(2)
- opts.Mode &= 01777
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return err
- }
- for {
- err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts)
- if err == nil {
- vfs.putResolvingPath(rp)
- return nil
- }
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return err
- }
- }
-}
-
-// OpenAt returns a FileDescription providing access to the file at the given
-// path. A reference is taken on the returned FileDescription.
-func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
- // Remove:
- //
- // - O_LARGEFILE, which we always report in FileDescription status flags
- // since only 64-bit architectures are supported at this time.
- //
- // - O_CLOEXEC, which affects file descriptors and therefore must be
- // handled outside of VFS.
- //
- // - Unknown flags.
- opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE
- // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC.
- if opts.Flags&linux.O_SYNC != 0 {
- opts.Flags |= linux.O_DSYNC
- }
- // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified
- // with O_DIRECTORY and a writable access mode (to ensure that it fails on
- // filesystem implementations that do not support it).
- if opts.Flags&linux.O_TMPFILE != 0 {
- if opts.Flags&linux.O_DIRECTORY == 0 {
- return nil, syserror.EINVAL
- }
- if opts.Flags&linux.O_CREAT != 0 {
- return nil, syserror.EINVAL
- }
- if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY {
- return nil, syserror.EINVAL
- }
- }
- // O_PATH causes most other flags to be ignored.
- if opts.Flags&linux.O_PATH != 0 {
- opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH
- }
- // "On Linux, the following bits are also honored in mode: [S_ISUID,
- // S_ISGID, S_ISVTX]" - open(2)
- opts.Mode &= 07777
-
- if opts.Flags&linux.O_NOFOLLOW != 0 {
- pop.FollowFinalSymlink = false
- }
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return nil, err
- }
- if opts.Flags&linux.O_DIRECTORY != 0 {
- rp.mustBeDir = true
- rp.mustBeDirOrig = true
- }
- for {
- fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
- if err == nil {
- vfs.putResolvingPath(rp)
- return fd, nil
- }
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return nil, err
- }
- }
-}
-
-// StatAt returns metadata for the file at the given path.
-func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return linux.Statx{}, err
- }
- for {
- stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts)
- if err == nil {
- vfs.putResolvingPath(rp)
- return stat, nil
- }
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return linux.Statx{}, err
- }
- }
-}
-
-// StatusFlags returns file description status flags.
-func (fd *FileDescription) StatusFlags(ctx context.Context) (uint32, error) {
- flags, err := fd.impl.StatusFlags(ctx)
- flags |= linux.O_LARGEFILE
- return flags, err
-}
-
-// SetStatusFlags sets file description status flags.
-func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) error {
- return fd.impl.SetStatusFlags(ctx, flags)
-}
-
-// TODO:
-//
-// - VFS.SyncAllFilesystems() for sync(2)
-//
-// - Something for syncfs(2)
-//
-// - VFS.LinkAt()
-//
-// - VFS.MknodAt()
-//
-// - VFS.ReadlinkAt()
-//
-// - VFS.RenameAt()
-//
-// - VFS.RmdirAt()
-//
-// - VFS.SetStatAt()
-//
-// - VFS.StatFSAt()
-//
-// - VFS.SymlinkAt()
-//
-// - VFS.UnlinkAt()
-//
-// - FileDescription.(almost everything)
diff --git a/pkg/sentry/vfs/testutil.go b/pkg/sentry/vfs/testutil.go
index 70b192ece..d94117bce 100644
--- a/pkg/sentry/vfs/testutil.go
+++ b/pkg/sentry/vfs/testutil.go
@@ -15,7 +15,10 @@
package vfs
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
@@ -33,10 +36,10 @@ type FDTestFilesystem struct {
vfsfs Filesystem
}
-// NewFilesystem implements FilesystemType.NewFilesystem.
-func (fstype FDTestFilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts NewFilesystemOptions) (*Filesystem, *Dentry, error) {
+// GetFilesystem implements FilesystemType.GetFilesystem.
+func (fstype FDTestFilesystemType) GetFilesystem(ctx context.Context, vfsObj *VirtualFilesystem, creds *auth.Credentials, source string, opts GetFilesystemOptions) (*Filesystem, *Dentry, error) {
var fs FDTestFilesystem
- fs.vfsfs.Init(&fs)
+ fs.vfsfs.Init(vfsObj, &fs)
return &fs.vfsfs, fs.NewDentry(), nil
}
@@ -114,6 +117,32 @@ func (fs *FDTestFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) err
return syserror.EPERM
}
+// ListxattrAt implements FilesystemImpl.ListxattrAt.
+func (fs *FDTestFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath) ([]string, error) {
+ return nil, syserror.EPERM
+}
+
+// GetxattrAt implements FilesystemImpl.GetxattrAt.
+func (fs *FDTestFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, name string) (string, error) {
+ return "", syserror.EPERM
+}
+
+// SetxattrAt implements FilesystemImpl.SetxattrAt.
+func (fs *FDTestFilesystem) SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error {
+ return syserror.EPERM
+}
+
+// RemovexattrAt implements FilesystemImpl.RemovexattrAt.
+func (fs *FDTestFilesystem) RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error {
+ return syserror.EPERM
+}
+
+// PrependPath implements FilesystemImpl.PrependPath.
+func (fs *FDTestFilesystem) PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error {
+ b.PrependComponent(fmt.Sprintf("vfs.fdTestDentry:%p", vd.dentry.impl.(*fdTestDentry)))
+ return PrependPathSyntheticError{}
+}
+
type fdTestDentry struct {
vfsd Dentry
}
@@ -126,14 +155,14 @@ func (fs *FDTestFilesystem) NewDentry() *Dentry {
}
// IncRef implements DentryImpl.IncRef.
-func (d *fdTestDentry) IncRef(vfsfs *Filesystem) {
+func (d *fdTestDentry) IncRef() {
}
// TryIncRef implements DentryImpl.TryIncRef.
-func (d *fdTestDentry) TryIncRef(vfsfs *Filesystem) bool {
+func (d *fdTestDentry) TryIncRef() bool {
return true
}
// DecRef implements DentryImpl.DecRef.
-func (d *fdTestDentry) DecRef(vfsfs *Filesystem) {
+func (d *fdTestDentry) DecRef() {
}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 4a8a69540..e60898d7c 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -16,13 +16,24 @@
//
// Lock order:
//
-// Filesystem implementation locks
+// FilesystemImpl/FileDescriptionImpl locks
// VirtualFilesystem.mountMu
+// Dentry.mu
+// Locks acquired by FilesystemImpls between Prepare{Delete,Rename}Dentry and Commit{Delete,Rename*}Dentry
+// VirtualFilesystem.filesystemsMu
// VirtualFilesystem.fsTypesMu
+//
+// Locking Dentry.mu in multiple Dentries requires holding
+// VirtualFilesystem.mountMu.
package vfs
import (
"sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
)
// A VirtualFilesystem (VFS for short) combines Filesystems in trees of Mounts.
@@ -33,7 +44,7 @@ type VirtualFilesystem struct {
// mountMu serializes mount mutations.
//
// mountMu is analogous to Linux's namespace_sem.
- mountMu sync.RWMutex
+ mountMu sync.Mutex
// mounts maps (mount parent, mount point) pairs to mounts. (Since mounts
// are uniquely namespaced, including mount parent in the key correctly
@@ -52,7 +63,7 @@ type VirtualFilesystem struct {
// mountpoints maps mount points to mounts at those points in all
// namespaces. mountpoints is protected by mountMu.
//
- // mountpoints is used to find mounts that must be unmounted due to
+ // mountpoints is used to find mounts that must be umounted due to
// removal of a mount point Dentry from another mount namespace. ("A file
// or directory that is a mount point in one namespace that is not a mount
// point in another namespace, may be renamed, unlinked, or removed
@@ -62,6 +73,11 @@ type VirtualFilesystem struct {
// mountpoints is analogous to Linux's mountpoint_hashtable.
mountpoints map[*Dentry]map[*Mount]struct{}
+ // filesystems contains all Filesystems. filesystems is protected by
+ // filesystemsMu.
+ filesystemsMu sync.Mutex
+ filesystems map[*Filesystem]struct{}
+
// fsTypes contains all FilesystemTypes that are usable in the
// VirtualFilesystem. fsTypes is protected by fsTypesMu.
fsTypesMu sync.RWMutex
@@ -72,12 +88,466 @@ type VirtualFilesystem struct {
func New() *VirtualFilesystem {
vfs := &VirtualFilesystem{
mountpoints: make(map[*Dentry]map[*Mount]struct{}),
+ filesystems: make(map[*Filesystem]struct{}),
fsTypes: make(map[string]FilesystemType),
}
vfs.mounts.Init()
return vfs
}
+// PathOperation specifies the path operated on by a VFS method.
+//
+// PathOperation is passed to VFS methods by pointer to reduce memory copying:
+// it's somewhat large and should never escape. (Options structs are passed by
+// pointer to VFS and FileDescription methods for the same reason.)
+type PathOperation struct {
+ // Root is the VFS root. References on Root are borrowed from the provider
+ // of the PathOperation.
+ //
+ // Invariants: Root.Ok().
+ Root VirtualDentry
+
+ // Start is the starting point for the path traversal. References on Start
+ // are borrowed from the provider of the PathOperation (i.e. the caller of
+ // the VFS method to which the PathOperation was passed).
+ //
+ // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root.
+ Start VirtualDentry
+
+ // Pathname is the pathname traversed by this operation.
+ Pathname string
+
+ // If FollowFinalSymlink is true, and the Dentry traversed by the final
+ // path component represents a symbolic link, the symbolic link should be
+ // followed.
+ FollowFinalSymlink bool
+}
+
+// GetDentryAt returns a VirtualDentry representing the given path, at which a
+// file must exist. A reference is taken on the returned VirtualDentry.
+func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return VirtualDentry{}, err
+ }
+ for {
+ d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts)
+ if err == nil {
+ vd := VirtualDentry{
+ mount: rp.mount,
+ dentry: d,
+ }
+ rp.mount.IncRef()
+ vfs.putResolvingPath(rp)
+ return vd, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return VirtualDentry{}, err
+ }
+ }
+}
+
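A hedged usage sketch of the reference contract: the PathOperation is borrowed, but on success GetDentryAt hands back a referenced VirtualDentry that the caller must release. vfsObj, ctx, creds, and a PathOperation pop (as sketched above) are assumed to be in scope:

	vd, err := vfsObj.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{})
	if err != nil {
		return err
	}
	defer vd.DecRef() // Release the reference taken by GetDentryAt.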
+// LinkAt creates a hard link at newpop representing the existing file at
+// oldpop.
+func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation) error {
+ oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{})
+ if err != nil {
+ return err
+ }
+ rp, err := vfs.getResolvingPath(creds, newpop)
+ if err != nil {
+ oldVD.DecRef()
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD)
+ if err == nil {
+ oldVD.DecRef()
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ oldVD.DecRef()
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// MkdirAt creates a directory at the given path.
+func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error {
+ // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is
+ // also honored." - mkdir(2)
+ opts.Mode &= 0777 | linux.S_ISVTX
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// MknodAt creates a file of the given mode at the given path. It returns an
+// error from the syserror package.
+func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+		return err
+ }
+ for {
+ if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ // Handle mount traversals.
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// OpenAt returns a FileDescription providing access to the file at the given
+// path. A reference is taken on the returned FileDescription.
+func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
+ // Remove:
+ //
+ // - O_LARGEFILE, which we always report in FileDescription status flags
+ // since only 64-bit architectures are supported at this time.
+ //
+ // - O_CLOEXEC, which affects file descriptors and therefore must be
+ // handled outside of VFS.
+ //
+ // - Unknown flags.
+ opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE
+ // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC.
+ if opts.Flags&linux.O_SYNC != 0 {
+ opts.Flags |= linux.O_DSYNC
+ }
+ // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified
+ // with O_DIRECTORY and a writable access mode (to ensure that it fails on
+ // filesystem implementations that do not support it).
+ if opts.Flags&linux.O_TMPFILE != 0 {
+ if opts.Flags&linux.O_DIRECTORY == 0 {
+ return nil, syserror.EINVAL
+ }
+ if opts.Flags&linux.O_CREAT != 0 {
+ return nil, syserror.EINVAL
+ }
+ if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY {
+ return nil, syserror.EINVAL
+ }
+ }
+ // O_PATH causes most other flags to be ignored.
+ if opts.Flags&linux.O_PATH != 0 {
+ opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH
+ }
+ // "On Linux, the following bits are also honored in mode: [S_ISUID,
+ // S_ISGID, S_ISVTX]" - open(2)
+ opts.Mode &= 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
+
+ if opts.Flags&linux.O_NOFOLLOW != 0 {
+ pop.FollowFinalSymlink = false
+ }
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return nil, err
+ }
+ if opts.Flags&linux.O_DIRECTORY != 0 {
+ rp.mustBeDir = true
+ rp.mustBeDirOrig = true
+ }
+ for {
+ fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return fd, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return nil, err
+ }
+ }
+}
+
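To make the O_TMPFILE validation above concrete, here is a sketch of a call that passes it: O_DIRECTORY plus a writable access mode, and no O_CREAT. As before, vfsObj, ctx, creds, and pop are assumed to be in scope, and fd.DecRef releases the reference returned by OpenAt:

	fd, err := vfsObj.OpenAt(ctx, creds, &pop, &vfs.OpenOptions{
		Flags: linux.O_TMPFILE | linux.O_DIRECTORY | linux.O_RDWR,
		Mode:  0600,
	})
	if err != nil {
		return err
	}
	defer fd.DecRef()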
+// ReadlinkAt returns the target of the symbolic link at the given path.
+func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (string, error) {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return "", err
+ }
+ for {
+ target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return target, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return "", err
+ }
+ }
+}
+
+// RenameAt renames the file at oldpop to newpop.
+func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation, opts *RenameOptions) error {
+ oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{})
+ if err != nil {
+ return err
+ }
+ rp, err := vfs.getResolvingPath(creds, newpop)
+ if err != nil {
+ oldVD.DecRef()
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.RenameAt(ctx, rp, oldVD, *opts)
+ if err == nil {
+ oldVD.DecRef()
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ oldVD.DecRef()
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// RmdirAt removes the directory at the given path.
+func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.RmdirAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// SetStatAt changes metadata for the file at the given path.
+func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetStatOptions) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// StatAt returns metadata for the file at the given path.
+func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ for {
+ stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return stat, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return linux.Statx{}, err
+ }
+ }
+}
+
+// StatFSAt returns metadata for the filesystem containing the file at the
+// given path.
+func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (linux.Statfs, error) {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ for {
+ statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return statfs, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return linux.Statfs{}, err
+ }
+ }
+}
+
+// SymlinkAt creates a symbolic link at the given path with the given target.
+func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// UnlinkAt deletes the non-directory file at the given path.
+func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.UnlinkAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// ListxattrAt returns all extended attribute names for the file at the given
+// path.
+func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) ([]string, error) {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return nil, err
+ }
+ for {
+ names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return names, nil
+ }
+ if err == syserror.ENOTSUP {
+ // Linux doesn't actually return ENOTSUP in this case; instead,
+ // fs/xattr.c:vfs_listxattr() falls back to allowing the security
+ // subsystem to return security extended attributes, which by
+ // default don't exist.
+ vfs.putResolvingPath(rp)
+ return nil, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return nil, err
+ }
+ }
+}
+
+// GetxattrAt returns the value associated with the given extended attribute
+// for the file at the given path.
+func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) (string, error) {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return "", err
+ }
+ for {
+ val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, name)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return val, nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return "", err
+ }
+ }
+}
+
+// SetxattrAt changes the value associated with the given extended attribute
+// for the file at the given path.
+func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetxattrOptions) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// RemovexattrAt removes the given extended attribute from the file at rp.
+func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return err
+ }
+ for {
+ err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name)
+ if err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
+// SyncAllFilesystems has the semantics of Linux's sync(2).
+func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error {
+ fss := make(map[*Filesystem]struct{})
+ vfs.filesystemsMu.Lock()
+ for fs := range vfs.filesystems {
+ if !fs.TryIncRef() {
+ continue
+ }
+ fss[fs] = struct{}{}
+ }
+ vfs.filesystemsMu.Unlock()
+ var retErr error
+ for fs := range fss {
+ if err := fs.impl.Sync(ctx); err != nil && retErr == nil {
+ retErr = err
+ }
+ fs.DecRef()
+ }
+ return retErr
+}
+
// A VirtualDentry represents a node in a VFS tree, by combining a Dentry
// (which represents a node in a Filesystem's tree) and a Mount (which
// represents the Filesystem's position in a VFS mount tree).
@@ -111,15 +581,15 @@ func (vd VirtualDentry) Ok() bool {
// IncRef increments the reference counts on the Mount and Dentry represented
// by vd.
func (vd VirtualDentry) IncRef() {
- vd.mount.incRef()
- vd.dentry.incRef(vd.mount.fs)
+ vd.mount.IncRef()
+ vd.dentry.IncRef()
}
// DecRef decrements the reference counts on the Mount and Dentry represented
// by vd.
func (vd VirtualDentry) DecRef() {
- vd.dentry.decRef(vd.mount.fs)
- vd.mount.decRef()
+ vd.dentry.DecRef()
+ vd.mount.DecRef()
}
// Mount returns the Mount associated with vd. It does not take a reference on
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index 145102c0d..5e4611333 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -42,8 +42,35 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
)
-// DefaultTimeout is a resonable timeout value for most applications.
-const DefaultTimeout = 3 * time.Minute
+// Opts configures the watchdog.
+type Opts struct {
+ // TaskTimeout is the amount of time to allow a task to execute the
+ // same syscall without blocking before it's declared stuck.
+ TaskTimeout time.Duration
+
+	// TaskTimeoutAction indicates what action to take when a stuck task
+ // is detected.
+ TaskTimeoutAction Action
+
+ // StartupTimeout is the amount of time to allow between watchdog
+ // creation and calling watchdog.Start.
+ StartupTimeout time.Duration
+
+ // StartupTimeoutAction indicates what action to take when
+ // watchdog.Start is not called within the timeout.
+ StartupTimeoutAction Action
+}
+
+// DefaultOpts is a default set of options for the watchdog.
+var DefaultOpts = Opts{
+ // Task timeout.
+ TaskTimeout: 3 * time.Minute,
+ TaskTimeoutAction: LogWarning,
+
+ // Startup timeout.
+ StartupTimeout: 30 * time.Second,
+ StartupTimeoutAction: LogWarning,
+}
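A sketch of wiring the new options in, assuming a *kernel.Kernel named k is already available; note that New must be followed by Start within StartupTimeout, or StartupTimeoutAction fires:

	opts := watchdog.DefaultOpts
	opts.TaskTimeout = 5 * time.Minute
	opts.TaskTimeoutAction = watchdog.Panic
	w := watchdog.New(k, opts)
	w.Start() // Must happen within opts.StartupTimeout of New.
	defer w.Stop()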
// descheduleThreshold is the amount of time scheduling needs to be off before the entire wait period
// is discounted from task's last update time. It's set high enough that small scheduling delays won't
@@ -61,6 +88,7 @@ type Action int
const (
// LogWarning logs warning message followed by stack trace.
LogWarning Action = iota
+
// Panic will do the same logging as LogWarning and panic().
Panic
)
@@ -80,17 +108,13 @@ func (a Action) String() string {
// Watchdog is the main watchdog class. It controls a goroutine that periodically
// analyses all tasks and reports if any of them appear to be stuck.
type Watchdog struct {
+ // Configuration options are embedded.
+ Opts
+
// period indicates how often to check all tasks. It's calculated based on
- // 'taskTimeout'.
+	// Opts.TaskTimeout.
period time.Duration
- // taskTimeout is the amount of time to allow a task to execute the same syscall
- // without blocking before it's declared stuck.
- taskTimeout time.Duration
-
- // timeoutAction indicates what action to take when a stuck tasks is detected.
- timeoutAction Action
-
// k is where the tasks come from.
k *kernel.Kernel
@@ -113,8 +137,12 @@ type Watchdog struct {
// mu protects the fields below.
mu sync.Mutex
- // started is true if the watchdog has been started before.
- started bool
+ // running is true if the watchdog is running.
+ running bool
+
+ // startCalled is true if Start has ever been called. It remains true
+ // even if Stop is called.
+ startCalled bool
}
type offender struct {
@@ -122,58 +150,81 @@ type offender struct {
}
// New creates a new watchdog.
-func New(k *kernel.Kernel, taskTimeout time.Duration, a Action) *Watchdog {
- // 4 is arbitrary, just don't want to prolong 'taskTimeout' too much.
- period := taskTimeout / 4
- return &Watchdog{
- k: k,
- period: period,
- taskTimeout: taskTimeout,
- timeoutAction: a,
- offenders: make(map[*kernel.Task]*offender),
- stop: make(chan struct{}),
- done: make(chan struct{}),
+func New(k *kernel.Kernel, opts Opts) *Watchdog {
+ // 4 is arbitrary, just don't want to prolong 'TaskTimeout' too much.
+ period := opts.TaskTimeout / 4
+ w := &Watchdog{
+ Opts: opts,
+ k: k,
+ period: period,
+ offenders: make(map[*kernel.Task]*offender),
+ stop: make(chan struct{}),
+ done: make(chan struct{}),
+ }
+
+	// Handle StartupTimeout if set.
+	if w.StartupTimeout > 0 {
+		log.Infof("Watchdog waiting %v for startup", w.StartupTimeout)
+		go w.waitForStart() // S/R-SAFE: watchdog is stopped during save and restarted after restore.
}
+
+ return w
}
// Start starts the watchdog.
func (w *Watchdog) Start() {
- if w.taskTimeout == 0 {
- log.Infof("Watchdog disabled")
- return
- }
-
w.mu.Lock()
defer w.mu.Unlock()
- if w.started {
+ w.startCalled = true
+
+ if w.running {
return
}
+ if w.TaskTimeout == 0 {
+ log.Infof("Watchdog task timeout disabled")
+ return
+ }
w.lastRun = w.k.MonotonicClock().Now()
- log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.taskTimeout, w.timeoutAction)
+ log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.TaskTimeout, w.TaskTimeoutAction)
go w.loop() // S/R-SAFE: watchdog is stopped during save and restarted after restore.
- w.started = true
+ w.running = true
}
// Stop requests the watchdog to stop and wait for it.
func (w *Watchdog) Stop() {
- if w.taskTimeout == 0 {
+ if w.TaskTimeout == 0 {
return
}
w.mu.Lock()
defer w.mu.Unlock()
- if !w.started {
+ if !w.running {
return
}
log.Infof("Stopping watchdog")
w.stop <- struct{}{}
<-w.done
- w.started = false
+ w.running = false
log.Infof("Watchdog stopped")
}
+// waitForStart waits for Start to be called and takes action if it does not
+// happen within the startup timeout.
+func (w *Watchdog) waitForStart() {
+ <-time.After(w.StartupTimeout)
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if w.startCalled {
+ // We are fine.
+ return
+ }
+ var buf bytes.Buffer
+ buf.WriteString("Watchdog.Start() not called within %s:\n")
+ w.doAction(w.StartupTimeoutAction, false, &buf)
+}
+
// loop is the main watchdog routine. It only returns when 'Stop()' is called.
func (w *Watchdog) loop() {
// Loop until someone stops it.
@@ -202,7 +253,7 @@ func (w *Watchdog) runTurn() {
select {
case <-done:
- case <-time.After(w.taskTimeout):
+ case <-time.After(w.TaskTimeout):
// Report if the watchdog is not making progress.
// No one is watching the watchdog watcher though.
w.reportStuckWatchdog()
@@ -231,12 +282,14 @@ func (w *Watchdog) runTurn() {
if tsched.State == kernel.TaskGoroutineRunningSys {
lastUpdateTime := ktime.FromNanoseconds(int64(tsched.Timestamp * uint64(linux.ClockTick)))
elapsed := now.Sub(lastUpdateTime) - discount
- if elapsed > w.taskTimeout {
+ if elapsed > w.TaskTimeout {
tc, ok := w.offenders[t]
if !ok {
// New stuck task detected.
//
- // TODO(b/65849403): Tasks blocked doing IO may be considered stuck in kernel.
+ // Note that tasks blocked doing IO may be considered stuck in kernel,
+			// unless they are surrounded by
+ // Task.UninterruptibleSleepStart/Finish.
tc = &offender{lastUpdateTime: lastUpdateTime}
stuckTasks.Increment()
newTaskFound = true
@@ -261,28 +314,34 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo
tid := w.k.TaskSet().Root.IDOfTask(t)
buf.WriteString(fmt.Sprintf("\tTask tid: %v (%#x), entered RunSys state %v ago.\n", tid, uint64(tid), now.Sub(o.lastUpdateTime)))
}
+
buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine")
- w.onStuckTask(newTaskFound, &buf)
+
+	// Skip the stack dump if no new stuck task was found and not enough
+	// time has passed since the last stack dump was generated.
+	skipStack := !newTaskFound && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod
+	w.doAction(w.TaskTimeoutAction, skipStack, &buf)
}
func (w *Watchdog) reportStuckWatchdog() {
var buf bytes.Buffer
buf.WriteString("Watchdog goroutine is stuck:\n")
- w.onStuckTask(true, &buf)
+ w.doAction(w.TaskTimeoutAction, false, &buf)
}
-func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) {
- switch w.timeoutAction {
+// doAction takes the given action. If the action is LogWarning and skipStack
+// is true, the stack dump is skipped.
+func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) {
+ switch action {
case LogWarning:
- // Dump stack only if a new task is detected or if it sometime has passed since
- // the last time a stack dump was generated.
- if !newTaskFound && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod {
+ if skipStack {
msg.WriteString("\n...[stack dump skipped]...")
log.Warningf(msg.String())
- } else {
- log.TracebackAll(msg.String())
- w.lastStackDump = time.Now()
+ return
}
+ log.TracebackAll(msg.String())
+ w.lastStackDump = time.Now()
case Panic:
// Panic will skip over running tasks, which is likely the culprit here. So manually
@@ -301,5 +360,8 @@ func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) {
case <-time.After(1 * time.Second):
}
panic(fmt.Sprintf("Stack for running G's are skipped while panicking.\n%s", msg.String()))
+ default:
+ panic(fmt.Sprintf("Unknown watchdog action %v", action))
}
}
diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go
index 8f5e60a25..acbf0229b 100644
--- a/pkg/sleep/sleep_unsafe.go
+++ b/pkg/sleep/sleep_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.11
-// +build !go1.14
+// +build !go1.15
// Check go:linkname function signatures when updating Go version.
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
index 47e6b878a..590c241a3 100644
--- a/pkg/state/decode.go
+++ b/pkg/state/decode.go
@@ -16,6 +16,7 @@ package state
import (
"bytes"
+ "context"
"encoding/binary"
"errors"
"fmt"
@@ -133,6 +134,9 @@ func (os *objectState) findCycle() []*objectState {
// to ensure that all callbacks are executed, otherwise the callback graph was
// not acyclic.
type decodeState struct {
+ // ctx is the decode context.
+ ctx context.Context
+
// objectByID is the set of objects in progress.
objectsByID map[uint64]*objectState
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
index 5d9409a45..c5118d3a9 100644
--- a/pkg/state/encode.go
+++ b/pkg/state/encode.go
@@ -16,6 +16,7 @@ package state
import (
"container/list"
+ "context"
"encoding/binary"
"fmt"
"io"
@@ -38,6 +39,9 @@ type queuedObject struct {
// The encoding process is a breadth-first traversal of the object graph. The
// inherent races and dependencies are much simpler than the decode case.
type encodeState struct {
+ // ctx is the encode context.
+ ctx context.Context
+
// lastID is the last object ID.
//
// See idsByObject for context. Because of the special zero encoding
diff --git a/pkg/state/map.go b/pkg/state/map.go
index 7e6fefed4..4f3ebb0da 100644
--- a/pkg/state/map.go
+++ b/pkg/state/map.go
@@ -15,6 +15,7 @@
package state
import (
+ "context"
"fmt"
"reflect"
"sort"
@@ -219,3 +220,13 @@ func (m Map) AfterLoad(fn func()) {
// data dependencies have been cleared.
m.os.callbacks = append(m.os.callbacks, fn)
}
+
+// Context returns the current context object.
+func (m Map) Context() context.Context {
+ if m.es != nil {
+ return m.es.ctx
+ } else if m.ds != nil {
+ return m.ds.ctx
+ }
+ return context.Background() // No context.
+}
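A sketch of how the context threads through: a value attached to the context passed to state.Save becomes visible inside custom save hooks via Map.Context. The key type, hook shape, and myState type here are illustrative assumptions, not part of this diff:

	type checkpointKey struct{}

	// mySave sketches a custom save hook (shape assumed) that consults the
	// context supplied to state.Save.
	func mySave(m state.Map) {
		if v := m.Context().Value(checkpointKey{}); v != nil {
			// Parameterize encoding based on v.
		}
	}

	func doSave(root *myState, buf *bytes.Buffer) error {
		ctx := context.WithValue(context.Background(), checkpointKey{}, "checkpoint-1")
		return state.Save(ctx, buf, root, nil)
	}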
diff --git a/pkg/state/object.proto b/pkg/state/object.proto
index 952289069..5ebcfb151 100644
--- a/pkg/state/object.proto
+++ b/pkg/state/object.proto
@@ -18,8 +18,8 @@ package gvisor.state.statefile;
// Slice is a slice value.
message Slice {
- uint32 length = 1;
- uint32 capacity = 2;
+ uint32 length = 1;
+ uint32 capacity = 2;
uint64 ref_value = 3;
}
@@ -30,13 +30,13 @@ message Array {
// Map is a map value.
message Map {
- repeated Object keys = 1;
+ repeated Object keys = 1;
repeated Object values = 2;
}
// Interface is an interface value.
message Interface {
- string type = 1;
+ string type = 1;
Object value = 2;
}
@@ -47,7 +47,7 @@ message Struct {
// Field encodes a single field.
message Field {
- string name = 1;
+ string name = 1;
Object value = 2;
}
@@ -113,28 +113,28 @@ message Float32s {
// Note that ref_value references an Object.id, below.
message Object {
oneof value {
- bool bool_value = 1;
- bytes string_value = 2;
- int64 int64_value = 3;
- uint64 uint64_value = 4;
- double double_value = 5;
- uint64 ref_value = 6;
- Slice slice_value = 7;
- Array array_value = 8;
- Interface interface_value = 9;
- Struct struct_value = 10;
- Map map_value = 11;
- bytes byte_array_value = 12;
- Uint16s uint16_array_value = 13;
- Uint32s uint32_array_value = 14;
- Uint64s uint64_array_value = 15;
- Uintptrs uintptr_array_value = 16;
- Int8s int8_array_value = 17;
- Int16s int16_array_value = 18;
- Int32s int32_array_value = 19;
- Int64s int64_array_value = 20;
- Bools bool_array_value = 21;
- Float64s float64_array_value = 22;
- Float32s float32_array_value = 23;
+ bool bool_value = 1;
+ bytes string_value = 2;
+ int64 int64_value = 3;
+ uint64 uint64_value = 4;
+ double double_value = 5;
+ uint64 ref_value = 6;
+ Slice slice_value = 7;
+ Array array_value = 8;
+ Interface interface_value = 9;
+ Struct struct_value = 10;
+ Map map_value = 11;
+ bytes byte_array_value = 12;
+ Uint16s uint16_array_value = 13;
+ Uint32s uint32_array_value = 14;
+ Uint64s uint64_array_value = 15;
+ Uintptrs uintptr_array_value = 16;
+ Int8s int8_array_value = 17;
+ Int16s int16_array_value = 18;
+ Int32s int32_array_value = 19;
+ Int64s int64_array_value = 20;
+ Bools bool_array_value = 21;
+ Float64s float64_array_value = 22;
+ Float32s float32_array_value = 23;
}
}
diff --git a/pkg/state/state.go b/pkg/state/state.go
index d408ff84a..dbe507ab4 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -50,6 +50,7 @@
package state
import (
+ "context"
"fmt"
"io"
"reflect"
@@ -86,9 +87,10 @@ func UnwrapErrState(err error) error {
}
// Save saves the given object state.
-func Save(w io.Writer, rootPtr interface{}, stats *Stats) error {
+func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error {
// Create the encoding state.
es := &encodeState{
+ ctx: ctx,
idsByObject: make(map[uintptr]uint64),
w: w,
stats: stats,
@@ -101,9 +103,10 @@ func Save(w io.Writer, rootPtr interface{}, stats *Stats) error {
}
// Load loads a checkpoint.
-func Load(r io.Reader, rootPtr interface{}, stats *Stats) error {
+func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error {
// Create the decoding state.
ds := &decodeState{
+ ctx: ctx,
objectsByID: make(map[uint64]*objectState),
deferred: make(map[uint64]*pb.Object),
r: r,
diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go
index 7c24bbcda..d7221e9e8 100644
--- a/pkg/state/state_test.go
+++ b/pkg/state/state_test.go
@@ -16,6 +16,7 @@ package state
import (
"bytes"
+ "context"
"io/ioutil"
"math"
"reflect"
@@ -46,7 +47,7 @@ func runTest(t *testing.T, tests []TestCase) {
saveBuffer := &bytes.Buffer{}
saveObjectPtr := reflect.New(reflect.TypeOf(root))
saveObjectPtr.Elem().Set(reflect.ValueOf(root))
- if err := Save(saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail {
+ if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail {
t.Errorf(" FAIL: Save failed unexpectedly: %v", err)
continue
} else if err != nil {
@@ -56,7 +57,7 @@ func runTest(t *testing.T, tests []TestCase) {
// Load a new copy of the object.
loadObjectPtr := reflect.New(reflect.TypeOf(root))
- if err := Load(bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail {
+ if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail {
t.Errorf(" FAIL: Load failed unexpectedly: %v", err)
continue
} else if err != nil {
@@ -624,7 +625,7 @@ func BenchmarkEncoding(b *testing.B) {
bs := buildObject(b.N)
var stats Stats
b.StartTimer()
- if err := Save(ioutil.Discard, bs, &stats); err != nil {
+ if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil {
b.Errorf("save failed: %v", err)
}
b.StopTimer()
@@ -638,12 +639,12 @@ func BenchmarkDecoding(b *testing.B) {
bs := buildObject(b.N)
var newBS benchStruct
buf := &bytes.Buffer{}
- if err := Save(buf, bs, nil); err != nil {
+ if err := Save(context.Background(), buf, bs, nil); err != nil {
b.Errorf("save failed: %v", err)
}
var stats Stats
b.StartTimer()
- if err := Load(buf, &newBS, &stats); err != nil {
+ if err := Load(context.Background(), buf, &newBS, &stats); err != nil {
b.Errorf("load failed: %v", err)
}
b.StopTimer()
diff --git a/pkg/syncutil/BUILD b/pkg/syncutil/BUILD
new file mode 100644
index 000000000..cb1f41628
--- /dev/null
+++ b/pkg/syncutil/BUILD
@@ -0,0 +1,52 @@
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+exports_files(["LICENSE"])
+
+go_template(
+ name = "generic_atomicptr",
+ srcs = ["atomicptr_unsafe.go"],
+ types = [
+ "Value",
+ ],
+)
+
+go_template(
+ name = "generic_seqatomic",
+ srcs = ["seqatomic_unsafe.go"],
+ types = [
+ "Value",
+ ],
+ deps = [
+ ":sync",
+ ],
+)
+
+go_library(
+ name = "syncutil",
+ srcs = [
+ "downgradable_rwmutex_unsafe.go",
+ "memmove_unsafe.go",
+ "norace_unsafe.go",
+ "race_unsafe.go",
+ "seqcount.go",
+ "syncutil.go",
+ ],
+ importpath = "gvisor.dev/gvisor/pkg/syncutil",
+)
+
+go_test(
+ name = "syncutil_test",
+ size = "small",
+ srcs = [
+ "downgradable_rwmutex_test.go",
+ "seqcount_test.go",
+ ],
+ embed = [":syncutil"],
+)
diff --git a/pkg/syncutil/LICENSE b/pkg/syncutil/LICENSE
new file mode 100644
index 000000000..6a66aea5e
--- /dev/null
+++ b/pkg/syncutil/LICENSE
@@ -0,0 +1,27 @@
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/pkg/syncutil/README.md b/pkg/syncutil/README.md
new file mode 100644
index 000000000..2183c4e20
--- /dev/null
+++ b/pkg/syncutil/README.md
@@ -0,0 +1,5 @@
+# Syncutil
+
+This package provides additional synchronization primitives not provided by the
+Go stdlib 'sync' package. It is partially derived from the upstream 'sync'
+package from go1.10.
diff --git a/pkg/syncutil/atomicptr_unsafe.go b/pkg/syncutil/atomicptr_unsafe.go
new file mode 100644
index 000000000..525c4beed
--- /dev/null
+++ b/pkg/syncutil/atomicptr_unsafe.go
@@ -0,0 +1,47 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package template doesn't exist. This file must be instantiated using the
+// go_template_instance rule in tools/go_generics/defs.bzl.
+package template
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+// Value is a required type parameter.
+type Value struct{}
+
+// An AtomicPtr is a pointer to a value of type Value that can be atomically
+// loaded and stored. The zero value of an AtomicPtr represents nil.
+//
+// Note that copying AtomicPtr by value performs a non-atomic read of the
+// stored pointer, which is unsafe if Store() can be called concurrently; in
+// this case, do `dst.Store(src.Load())` instead.
+//
+// +stateify savable
+type AtomicPtr struct {
+ ptr unsafe.Pointer `state:".(*Value)"`
+}
+
+func (p *AtomicPtr) savePtr() *Value {
+ return p.Load()
+}
+
+func (p *AtomicPtr) loadPtr(v *Value) {
+ p.Store(v)
+}
+
+// Load returns the value set by the most recent Store. It returns nil if there
+// has been no previous call to Store.
+func (p *AtomicPtr) Load() *Value {
+ return (*Value)(atomic.LoadPointer(&p.ptr))
+}
+
+// Store sets the value returned by Load to x.
+func (p *AtomicPtr) Store(x *Value) {
+ atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x))
+}
diff --git a/pkg/syncutil/atomicptrtest/BUILD b/pkg/syncutil/atomicptrtest/BUILD
new file mode 100644
index 000000000..63f411a90
--- /dev/null
+++ b/pkg/syncutil/atomicptrtest/BUILD
@@ -0,0 +1,29 @@
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "atomicptr_int",
+ out = "atomicptr_int_unsafe.go",
+ package = "atomicptr",
+ suffix = "Int",
+ template = "//pkg/syncutil:generic_atomicptr",
+ types = {
+ "Value": "int",
+ },
+)
+
+go_library(
+ name = "atomicptr",
+ srcs = ["atomicptr_int_unsafe.go"],
+ importpath = "gvisor.dev/gvisor/pkg/syncutil/atomicptr",
+)
+
+go_test(
+ name = "atomicptr_test",
+ size = "small",
+ srcs = ["atomicptr_test.go"],
+ embed = [":atomicptr"],
+)
diff --git a/pkg/syncutil/atomicptrtest/atomicptr_test.go b/pkg/syncutil/atomicptrtest/atomicptr_test.go
new file mode 100644
index 000000000..8fdc5112e
--- /dev/null
+++ b/pkg/syncutil/atomicptrtest/atomicptr_test.go
@@ -0,0 +1,31 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package atomicptr
+
+import (
+ "testing"
+)
+
+func newInt(val int) *int {
+ return &val
+}
+
+func TestAtomicPtr(t *testing.T) {
+ var p AtomicPtrInt
+ if got := p.Load(); got != nil {
+ t.Errorf("initial value is %p (%v), wanted nil", got, got)
+ }
+ want := newInt(42)
+ p.Store(want)
+ if got := p.Load(); got != want {
+ t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want)
+ }
+ want = newInt(100)
+ p.Store(want)
+ if got := p.Load(); got != want {
+ t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want)
+ }
+}
diff --git a/pkg/syncutil/downgradable_rwmutex_test.go b/pkg/syncutil/downgradable_rwmutex_test.go
new file mode 100644
index 000000000..ffaf7ecc7
--- /dev/null
+++ b/pkg/syncutil/downgradable_rwmutex_test.go
@@ -0,0 +1,150 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Copyright 2019 The gVisor Authors.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// GOMAXPROCS=10 go test
+
+// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the
+// addition of downgradingWriter and the renaming of num_iterations to
+// numIterations to shut up Golint.
+
+package syncutil
+
+import (
+ "fmt"
+ "runtime"
+ "sync/atomic"
+ "testing"
+)
+
+func parallelReader(m *DowngradableRWMutex, clocked, cunlock, cdone chan bool) {
+ m.RLock()
+ clocked <- true
+ <-cunlock
+ m.RUnlock()
+ cdone <- true
+}
+
+func doTestParallelReaders(numReaders, gomaxprocs int) {
+ runtime.GOMAXPROCS(gomaxprocs)
+ var m DowngradableRWMutex
+ clocked := make(chan bool)
+ cunlock := make(chan bool)
+ cdone := make(chan bool)
+ for i := 0; i < numReaders; i++ {
+ go parallelReader(&m, clocked, cunlock, cdone)
+ }
+ // Wait for all parallel RLock()s to succeed.
+ for i := 0; i < numReaders; i++ {
+ <-clocked
+ }
+ for i := 0; i < numReaders; i++ {
+ cunlock <- true
+ }
+ // Wait for the goroutines to finish.
+ for i := 0; i < numReaders; i++ {
+ <-cdone
+ }
+}
+
+func TestParallelReaders(t *testing.T) {
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1))
+ doTestParallelReaders(1, 4)
+ doTestParallelReaders(3, 4)
+ doTestParallelReaders(4, 2)
+}
+
+func reader(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) {
+ for i := 0; i < numIterations; i++ {
+ rwm.RLock()
+ n := atomic.AddInt32(activity, 1)
+ if n < 1 || n >= 10000 {
+ panic(fmt.Sprintf("wlock(%d)\n", n))
+ }
+ for i := 0; i < 100; i++ {
+ }
+ atomic.AddInt32(activity, -1)
+ rwm.RUnlock()
+ }
+ cdone <- true
+}
+
+func writer(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) {
+ for i := 0; i < numIterations; i++ {
+ rwm.Lock()
+ n := atomic.AddInt32(activity, 10000)
+ if n != 10000 {
+ panic(fmt.Sprintf("wlock(%d)\n", n))
+ }
+ for i := 0; i < 100; i++ {
+ }
+ atomic.AddInt32(activity, -10000)
+ rwm.Unlock()
+ }
+ cdone <- true
+}
+
+func downgradingWriter(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) {
+ for i := 0; i < numIterations; i++ {
+ rwm.Lock()
+ n := atomic.AddInt32(activity, 10000)
+ if n != 10000 {
+ panic(fmt.Sprintf("wlock(%d)\n", n))
+ }
+ for i := 0; i < 100; i++ {
+ }
+ atomic.AddInt32(activity, -10000)
+ rwm.DowngradeLock()
+ n = atomic.AddInt32(activity, 1)
+ if n < 1 || n >= 10000 {
+ panic(fmt.Sprintf("wlock(%d)\n", n))
+ }
+ for i := 0; i < 100; i++ {
+ }
+ n = atomic.AddInt32(activity, -1)
+ rwm.RUnlock()
+ }
+ cdone <- true
+}
+
+func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) {
+ runtime.GOMAXPROCS(gomaxprocs)
+ // Number of active readers + 10000 * number of active writers.
+ var activity int32
+ var rwm DowngradableRWMutex
+ cdone := make(chan bool)
+ go writer(&rwm, numIterations, &activity, cdone)
+ go downgradingWriter(&rwm, numIterations, &activity, cdone)
+ var i int
+ for i = 0; i < numReaders/2; i++ {
+ go reader(&rwm, numIterations, &activity, cdone)
+ }
+ go writer(&rwm, numIterations, &activity, cdone)
+ go downgradingWriter(&rwm, numIterations, &activity, cdone)
+ for ; i < numReaders; i++ {
+ go reader(&rwm, numIterations, &activity, cdone)
+ }
+ // Wait for the 4 writers and all readers to finish.
+ for i := 0; i < 4+numReaders; i++ {
+ <-cdone
+ }
+}
+
+func TestDowngradableRWMutex(t *testing.T) {
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1))
+ n := 1000
+ if testing.Short() {
+ n = 5
+ }
+ HammerDowngradableRWMutex(1, 1, n)
+ HammerDowngradableRWMutex(1, 3, n)
+ HammerDowngradableRWMutex(1, 10, n)
+ HammerDowngradableRWMutex(4, 1, n)
+ HammerDowngradableRWMutex(4, 3, n)
+ HammerDowngradableRWMutex(4, 10, n)
+ HammerDowngradableRWMutex(10, 1, n)
+ HammerDowngradableRWMutex(10, 3, n)
+ HammerDowngradableRWMutex(10, 10, n)
+ HammerDowngradableRWMutex(10, 5, n)
+}
diff --git a/pkg/syncutil/downgradable_rwmutex_unsafe.go b/pkg/syncutil/downgradable_rwmutex_unsafe.go
new file mode 100644
index 000000000..51e11555d
--- /dev/null
+++ b/pkg/syncutil/downgradable_rwmutex_unsafe.go
@@ -0,0 +1,146 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Copyright 2019 The gVisor Authors.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.13
+// +build !go1.15
+
+// Check go:linkname function signatures when updating Go version.
+
+// This is mostly copied from the standard library's sync/rwmutex.go.
+//
+// Happens-before relationships indicated to the race detector:
+// - Unlock -> Lock (via writerSem)
+// - Unlock -> RLock (via readerSem)
+// - RUnlock -> Lock (via writerSem)
+// - DowngradeLock -> RLock (via readerSem)
+
+package syncutil
+
+import (
+ "sync"
+ "sync/atomic"
+ "unsafe"
+)
+
+//go:linkname runtimeSemacquire sync.runtime_Semacquire
+func runtimeSemacquire(s *uint32)
+
+//go:linkname runtimeSemrelease sync.runtime_Semrelease
+func runtimeSemrelease(s *uint32, handoff bool, skipframes int)
+
+// DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock
+// method.
+type DowngradableRWMutex struct {
+ w sync.Mutex // held if there are pending writers
+ writerSem uint32 // semaphore for writers to wait for completing readers
+ readerSem uint32 // semaphore for readers to wait for completing writers
+ readerCount int32 // number of pending readers
+ readerWait int32 // number of departing readers
+}
+
+const rwmutexMaxReaders = 1 << 30
+
+// RLock locks rw for reading.
+func (rw *DowngradableRWMutex) RLock() {
+ if RaceEnabled {
+ RaceDisable()
+ }
+ if atomic.AddInt32(&rw.readerCount, 1) < 0 {
+ // A writer is pending, wait for it.
+ runtimeSemacquire(&rw.readerSem)
+ }
+ if RaceEnabled {
+ RaceEnable()
+ RaceAcquire(unsafe.Pointer(&rw.readerSem))
+ }
+}
+
+// RUnlock undoes a single RLock call.
+func (rw *DowngradableRWMutex) RUnlock() {
+ if RaceEnabled {
+ RaceReleaseMerge(unsafe.Pointer(&rw.writerSem))
+ RaceDisable()
+ }
+ if r := atomic.AddInt32(&rw.readerCount, -1); r < 0 {
+ if r+1 == 0 || r+1 == -rwmutexMaxReaders {
+ panic("RUnlock of unlocked DowngradableRWMutex")
+ }
+ // A writer is pending.
+ if atomic.AddInt32(&rw.readerWait, -1) == 0 {
+ // The last reader unblocks the writer.
+ runtimeSemrelease(&rw.writerSem, false, 0)
+ }
+ }
+ if RaceEnabled {
+ RaceEnable()
+ }
+}
+
+// Lock locks rw for writing.
+func (rw *DowngradableRWMutex) Lock() {
+ if RaceEnabled {
+ RaceDisable()
+ }
+ // First, resolve competition with other writers.
+ rw.w.Lock()
+ // Announce to readers there is a pending writer.
+ r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders
+ // Wait for active readers.
+ if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 {
+ runtimeSemacquire(&rw.writerSem)
+ }
+ if RaceEnabled {
+ RaceEnable()
+ RaceAcquire(unsafe.Pointer(&rw.writerSem))
+ }
+}
+
+// Unlock unlocks rw for writing.
+func (rw *DowngradableRWMutex) Unlock() {
+ if RaceEnabled {
+ RaceRelease(unsafe.Pointer(&rw.writerSem))
+ RaceRelease(unsafe.Pointer(&rw.readerSem))
+ RaceDisable()
+ }
+ // Announce to readers there is no active writer.
+ r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders)
+ if r >= rwmutexMaxReaders {
+ panic("Unlock of unlocked DowngradableRWMutex")
+ }
+ // Unblock blocked readers, if any.
+ for i := 0; i < int(r); i++ {
+ runtimeSemrelease(&rw.readerSem, false, 0)
+ }
+ // Allow other writers to proceed.
+ rw.w.Unlock()
+ if RaceEnabled {
+ RaceEnable()
+ }
+}
+
+// DowngradeLock atomically unlocks rw for writing and locks it for reading.
+func (rw *DowngradableRWMutex) DowngradeLock() {
+ if RaceEnabled {
+ RaceRelease(unsafe.Pointer(&rw.readerSem))
+ RaceDisable()
+ }
+ // Announce to readers there is no active writer and one additional reader.
+ r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders+1)
+ if r >= rwmutexMaxReaders+1 {
+ panic("DowngradeLock of unlocked DowngradableRWMutex")
+ }
+	// Unblock blocked readers, if any. Note that this loop starts at 1 since r
+ // includes this goroutine.
+ for i := 1; i < int(r); i++ {
+ runtimeSemrelease(&rw.readerSem, false, 0)
+ }
+ // Allow other writers to proceed to rw.w.Lock(). Note that they will still
+ // block on rw.writerSem since at least this reader exists, such that
+ // DowngradeLock() is atomic with the previous write lock.
+ rw.w.Unlock()
+ if RaceEnabled {
+ RaceEnable()
+ }
+}
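A usage sketch: DowngradeLock lets a writer publish its update and keep reading a consistent view without blocking other readers, while later writers still wait until RUnlock:

	var mu syncutil.DowngradableRWMutex
	mu.Lock()
	// ... mutate shared state ...
	mu.DowngradeLock() // Atomically trade the write lock for a read lock.
	// ... read shared state; new writers queue behind this reader ...
	mu.RUnlock()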
diff --git a/pkg/syncutil/memmove_unsafe.go b/pkg/syncutil/memmove_unsafe.go
new file mode 100644
index 000000000..348675baa
--- /dev/null
+++ b/pkg/syncutil/memmove_unsafe.go
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.12
+// +build !go1.15
+
+// Check go:linkname function signatures when updating Go version.
+
+package syncutil
+
+import (
+ "unsafe"
+)
+
+//go:linkname memmove runtime.memmove
+//go:noescape
+func memmove(to, from unsafe.Pointer, n uintptr)
+
+// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad<T>, which can't
+// define it because go_generics can't update the go:linkname annotation.
+// Furthermore, go:linkname silently doesn't work if the local name is exported
+// (this is of course undocumented), which is why this indirection is
+// necessary.
+func Memmove(to, from unsafe.Pointer, n uintptr) {
+ memmove(to, from, n)
+}
diff --git a/pkg/syncutil/norace_unsafe.go b/pkg/syncutil/norace_unsafe.go
new file mode 100644
index 000000000..0a0a9deda
--- /dev/null
+++ b/pkg/syncutil/norace_unsafe.go
@@ -0,0 +1,35 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !race
+
+package syncutil
+
+import (
+ "unsafe"
+)
+
+// RaceEnabled is true if the Go data race detector is enabled.
+const RaceEnabled = false
+
+// RaceDisable has the same semantics as runtime.RaceDisable.
+func RaceDisable() {
+}
+
+// RaceEnable has the same semantics as runtime.RaceEnable.
+func RaceEnable() {
+}
+
+// RaceAcquire has the same semantics as runtime.RaceAcquire.
+func RaceAcquire(addr unsafe.Pointer) {
+}
+
+// RaceRelease has the same semantics as runtime.RaceRelease.
+func RaceRelease(addr unsafe.Pointer) {
+}
+
+// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge.
+func RaceReleaseMerge(addr unsafe.Pointer) {
+}
diff --git a/pkg/syncutil/race_unsafe.go b/pkg/syncutil/race_unsafe.go
new file mode 100644
index 000000000..206067ec1
--- /dev/null
+++ b/pkg/syncutil/race_unsafe.go
@@ -0,0 +1,41 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build race
+
+package syncutil
+
+import (
+ "runtime"
+ "unsafe"
+)
+
+// RaceEnabled is true if the Go data race detector is enabled.
+const RaceEnabled = true
+
+// RaceDisable has the same semantics as runtime.RaceDisable.
+func RaceDisable() {
+ runtime.RaceDisable()
+}
+
+// RaceEnable has the same semantics as runtime.RaceEnable.
+func RaceEnable() {
+ runtime.RaceEnable()
+}
+
+// RaceAcquire has the same semantics as runtime.RaceAcquire.
+func RaceAcquire(addr unsafe.Pointer) {
+ runtime.RaceAcquire(addr)
+}
+
+// RaceRelease has the same semantics as runtime.RaceRelease.
+func RaceRelease(addr unsafe.Pointer) {
+ runtime.RaceRelease(addr)
+}
+
+// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge.
+func RaceReleaseMerge(addr unsafe.Pointer) {
+ runtime.RaceReleaseMerge(addr)
+}
diff --git a/pkg/syncutil/seqatomic_unsafe.go b/pkg/syncutil/seqatomic_unsafe.go
new file mode 100644
index 000000000..cb6d2eb22
--- /dev/null
+++ b/pkg/syncutil/seqatomic_unsafe.go
@@ -0,0 +1,72 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package template doesn't exist. This file must be instantiated using the
+// go_template_instance rule in tools/go_generics/defs.bzl.
+package template
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/syncutil"
+)
+
+// Value is a required type parameter.
+//
+// Value must not contain any pointers, including interface objects, function
+// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs
+// containing any of the above. An init() function will panic if this property
+// does not hold.
+type Value struct{}
+
+// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race
+// with any writer critical sections in sc.
+func SeqAtomicLoad(sc *syncutil.SeqCount, ptr *Value) Value {
+ // This function doesn't use SeqAtomicTryLoad because doing so is
+ // measurably, significantly (~20%) slower; Go is awful at inlining.
+ var val Value
+ for {
+ epoch := sc.BeginRead()
+ if syncutil.RaceEnabled {
+ // runtime.RaceDisable() doesn't actually stop the race detector,
+ // so it can't help us here. Instead, call runtime.memmove
+ // directly, which is not instrumented by the race detector.
+ syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ } else {
+ // This is ~40% faster for short reads than going through memmove.
+ val = *ptr
+ }
+ if sc.ReadOk(epoch) {
+ break
+ }
+ }
+ return val
+}
+
+// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section
+// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read
+// would race with a writer critical section, SeqAtomicTryLoad returns
+// (unspecified, false).
+func SeqAtomicTryLoad(sc *syncutil.SeqCount, epoch syncutil.SeqCountEpoch, ptr *Value) (Value, bool) {
+ var val Value
+ if syncutil.RaceEnabled {
+ syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ } else {
+ val = *ptr
+ }
+ return val, sc.ReadOk(epoch)
+}
+
+func init() {
+ var val Value
+ typ := reflect.TypeOf(val)
+ name := typ.Name()
+ if ptrs := syncutil.PointersInType(typ, name); len(ptrs) != 0 {
+ panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n")))
+ }
+}
diff --git a/pkg/syncutil/seqatomictest/BUILD b/pkg/syncutil/seqatomictest/BUILD
new file mode 100644
index 000000000..ba18f3238
--- /dev/null
+++ b/pkg/syncutil/seqatomictest/BUILD
@@ -0,0 +1,35 @@
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "seqatomic_int",
+ out = "seqatomic_int_unsafe.go",
+ package = "seqatomic",
+ suffix = "Int",
+ template = "//pkg/syncutil:generic_seqatomic",
+ types = {
+ "Value": "int",
+ },
+)
+
+go_library(
+ name = "seqatomic",
+ srcs = ["seqatomic_int_unsafe.go"],
+ importpath = "gvisor.dev/gvisor/pkg/syncutil/seqatomic",
+ deps = [
+ "//pkg/syncutil",
+ ],
+)
+
+go_test(
+ name = "seqatomic_test",
+ size = "small",
+ srcs = ["seqatomic_test.go"],
+ embed = [":seqatomic"],
+ deps = [
+ "//pkg/syncutil",
+ ],
+)
diff --git a/pkg/syncutil/seqatomictest/seqatomic_test.go b/pkg/syncutil/seqatomictest/seqatomic_test.go
new file mode 100644
index 000000000..b0db44999
--- /dev/null
+++ b/pkg/syncutil/seqatomictest/seqatomic_test.go
@@ -0,0 +1,132 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seqatomic
+
+import (
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/syncutil"
+)
+
+func TestSeqAtomicLoadUncontended(t *testing.T) {
+ var seq syncutil.SeqCount
+ const want = 1
+ data := want
+ if got := SeqAtomicLoadInt(&seq, &data); got != want {
+ t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want)
+ }
+}
+
+func TestSeqAtomicLoadAfterWrite(t *testing.T) {
+ var seq syncutil.SeqCount
+ var data int
+ const want = 1
+ seq.BeginWrite()
+ data = want
+ seq.EndWrite()
+ if got := SeqAtomicLoadInt(&seq, &data); got != want {
+ t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want)
+ }
+}
+
+func TestSeqAtomicLoadDuringWrite(t *testing.T) {
+ var seq syncutil.SeqCount
+ var data int
+ const want = 1
+ seq.BeginWrite()
+ go func() {
+ time.Sleep(time.Second)
+ data = want
+ seq.EndWrite()
+ }()
+ if got := SeqAtomicLoadInt(&seq, &data); got != want {
+ t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want)
+ }
+}
+
+func TestSeqAtomicTryLoadUncontended(t *testing.T) {
+ var seq syncutil.SeqCount
+ const want = 1
+ data := want
+ epoch := seq.BeginRead()
+ if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want {
+ t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want)
+ }
+}
+
+func TestSeqAtomicTryLoadDuringWrite(t *testing.T) {
+ var seq syncutil.SeqCount
+ var data int
+ epoch := seq.BeginRead()
+ seq.BeginWrite()
+ if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok {
+ t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got)
+ }
+ seq.EndWrite()
+}
+
+func TestSeqAtomicTryLoadAfterWrite(t *testing.T) {
+ var seq syncutil.SeqCount
+ var data int
+ epoch := seq.BeginRead()
+ seq.BeginWrite()
+ seq.EndWrite()
+ if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok {
+ t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got)
+ }
+}
+
+func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) {
+ var seq syncutil.SeqCount
+ const want = 42
+ data := want
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ if got := SeqAtomicLoadInt(&seq, &data); got != want {
+ b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want)
+ }
+ }
+ })
+}
+
+func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) {
+ var seq syncutil.SeqCount
+ const want = 42
+ data := want
+ b.RunParallel(func(pb *testing.PB) {
+ epoch := seq.BeginRead()
+ for pb.Next() {
+ if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want {
+ b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want)
+ }
+ }
+ })
+}
+
+// For comparison:
+func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) {
+ var a atomic.Value
+ const want = 42
+ a.Store(int(want))
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ if got := a.Load().(int); got != want {
+ b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want)
+ }
+ }
+ })
+}
diff --git a/pkg/syncutil/seqcount.go b/pkg/syncutil/seqcount.go
new file mode 100644
index 000000000..11d8dbfaa
--- /dev/null
+++ b/pkg/syncutil/seqcount.go
@@ -0,0 +1,149 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package syncutil
+
+import (
+ "fmt"
+ "reflect"
+ "runtime"
+ "sync/atomic"
+)
+
+// SeqCount is a synchronization primitive for optimistic reader/writer
+// synchronization in cases where readers can work with stale data and
+// therefore do not need to block writers.
+//
+// Compared to sync/atomic.Value:
+//
+// - Mutation of SeqCount-protected data does not require memory allocation,
+// whereas atomic.Value generally does. This is a significant advantage when
+// writes are common.
+//
+// - Atomic reads of SeqCount-protected data require copying. This is a
+// disadvantage when atomic reads are common.
+//
+// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other
+// operations to be made atomic with reads of SeqCount-protected data.
+//
+// - SeqCount may be less flexible: as of this writing, SeqCount-protected data
+// cannot include pointers.
+//
+// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected
+// data require instantiating function templates using go_generics (see
+// seqatomic.go).
+type SeqCount struct {
+ // epoch is incremented by BeginWrite and EndWrite, such that epoch is odd
+ // if a writer critical section is active, and a read from data protected
+ // by this SeqCount is atomic iff epoch is the same even value before and
+ // after the read.
+ epoch uint32
+}
+
+// SeqCountEpoch tracks writer critical sections in a SeqCount.
+type SeqCountEpoch struct {
+ val uint32
+}
+
+// We assume that:
+//
+// - All functions in sync/atomic that perform a memory read are at least a
+// read fence: memory reads before calls to such functions cannot be reordered
+// after the call, and memory reads after calls to such functions cannot be
+// reordered before the call, even if those reads do not use sync/atomic.
+//
+// - All functions in sync/atomic that perform a memory write are at least a
+// write fence: memory writes before calls to such functions cannot be
+// reordered after the call, and memory writes after calls to such functions
+// cannot be reordered before the call, even if those writes do not use
+// sync/atomic.
+//
+// As of this writing, the Go memory model completely fails to describe
+// sync/atomic, but these properties are implied by
+// https://groups.google.com/forum/#!topic/golang-nuts/7EnEhM3U7B8.
+
+// BeginRead indicates the beginning of a reader critical section. Reader
+// critical sections DO NOT BLOCK writer critical sections, so operations in a
+// reader critical section MAY RACE with writer critical sections. Races are
+// detected by ReadOk at the end of the reader critical section. Thus, the
+// low-level structure of readers is generally:
+//
+// for {
+// epoch := seq.BeginRead()
+// // do something idempotent with seq-protected data
+// if seq.ReadOk(epoch) {
+// break
+// }
+// }
+//
+// However, since reader critical sections may race with writer critical
+// sections, the Go race detector will (accurately) flag data races in readers
+// using this pattern. Most users of SeqCount will need to use the
+// SeqAtomicLoad function template in seqatomic.go.
+func (s *SeqCount) BeginRead() SeqCountEpoch {
+ epoch := atomic.LoadUint32(&s.epoch)
+ for epoch&1 != 0 {
+ runtime.Gosched()
+ epoch = atomic.LoadUint32(&s.epoch)
+ }
+ return SeqCountEpoch{epoch}
+}
+
+// ReadOk returns true if the reader critical section initiated by a previous
+// call to BeginRead() that returned epoch did not race with any writer critical
+// sections.
+//
+// ReadOk may be called any number of times during a reader critical section.
+// Reader critical sections do not need to be explicitly terminated; the last
+// call to ReadOk is implicitly the end of the reader critical section.
+func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool {
+ return atomic.LoadUint32(&s.epoch) == epoch.val
+}
+
+// BeginWrite indicates the beginning of a writer critical section.
+//
+// SeqCount does not support concurrent writer critical sections; clients with
+// concurrent writers must synchronize them using e.g. sync.Mutex.
+func (s *SeqCount) BeginWrite() {
+ if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 == 0 {
+ panic("SeqCount.BeginWrite during writer critical section")
+ }
+}
+
+// EndWrite ends the effect of a preceding BeginWrite.
+func (s *SeqCount) EndWrite() {
+ if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 != 0 {
+ panic("SeqCount.EndWrite outside writer critical section")
+ }
+}
+
+// PointersInType returns a list of pointers reachable from values named
+// valName of the given type.
+//
+// PointersInType is not exhaustive, but it is guaranteed that if typ contains
+// at least one pointer, then PointersInType returns a non-empty list.
+func PointersInType(typ reflect.Type, valName string) []string {
+ switch kind := typ.Kind(); kind {
+ case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
+ return nil
+
+ case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer:
+ return []string{valName}
+
+ case reflect.Array:
+ return PointersInType(typ.Elem(), valName+"[]")
+
+ case reflect.Struct:
+ var ptrs []string
+ for i, n := 0, typ.NumField(); i < n; i++ {
+ field := typ.Field(i)
+ ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...)
+ }
+ return ptrs
+
+ default:
+ return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)}
+ }
+}
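For orientation, a minimal usage sketch of the primitive added above (illustrative only, not part of this change): writers serialize among themselves with a mutex and bracket updates with BeginWrite/EndWrite, while readers copy the protected data and retry until ReadOk confirms no writer raced. In real code, readers would go through the SeqAtomicLoad template referenced in the comments to stay clean under the race detector.

	type counters struct {
		mu  sync.Mutex        // serializes writers
		seq syncutil.SeqCount // lets readers detect racing writers
		val [2]uint64         // SeqCount-protected data: no pointers allowed
	}

	func (c *counters) store(a, b uint64) {
		c.mu.Lock()
		c.seq.BeginWrite()
		c.val[0], c.val[1] = a, b
		c.seq.EndWrite()
		c.mu.Unlock()
	}

	func (c *counters) load() [2]uint64 {
		for {
			epoch := c.seq.BeginRead()
			v := c.val // may race with a writer; validated below
			if c.seq.ReadOk(epoch) {
				return v
			}
		}
	}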
diff --git a/pkg/syncutil/seqcount_test.go b/pkg/syncutil/seqcount_test.go
new file mode 100644
index 000000000..14d6aedea
--- /dev/null
+++ b/pkg/syncutil/seqcount_test.go
@@ -0,0 +1,153 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package syncutil
+
+import (
+ "reflect"
+ "testing"
+ "time"
+)
+
+func TestSeqCountWriteUncontended(t *testing.T) {
+ var seq SeqCount
+ seq.BeginWrite()
+ seq.EndWrite()
+}
+
+func TestSeqCountReadUncontended(t *testing.T) {
+ var seq SeqCount
+ epoch := seq.BeginRead()
+ if !seq.ReadOk(epoch) {
+ t.Errorf("ReadOk: got false, wanted true")
+ }
+}
+
+func TestSeqCountBeginReadAfterWrite(t *testing.T) {
+ var seq SeqCount
+ var data int32
+ const want = 1
+ seq.BeginWrite()
+ data = want
+ seq.EndWrite()
+ epoch := seq.BeginRead()
+ if data != want {
+ t.Errorf("Reader: got %v, wanted %v", data, want)
+ }
+ if !seq.ReadOk(epoch) {
+ t.Errorf("ReadOk: got false, wanted true")
+ }
+}
+
+func TestSeqCountBeginReadDuringWrite(t *testing.T) {
+ var seq SeqCount
+ var data int
+ const want = 1
+ seq.BeginWrite()
+ go func() {
+ time.Sleep(time.Second)
+ data = want
+ seq.EndWrite()
+ }()
+ epoch := seq.BeginRead()
+ if data != want {
+ t.Errorf("Reader: got %v, wanted %v", data, want)
+ }
+ if !seq.ReadOk(epoch) {
+ t.Errorf("ReadOk: got false, wanted true")
+ }
+}
+
+func TestSeqCountReadOkAfterWrite(t *testing.T) {
+ var seq SeqCount
+ epoch := seq.BeginRead()
+ seq.BeginWrite()
+ seq.EndWrite()
+ if seq.ReadOk(epoch) {
+ t.Errorf("ReadOk: got true, wanted false")
+ }
+}
+
+func TestSeqCountReadOkDuringWrite(t *testing.T) {
+ var seq SeqCount
+ epoch := seq.BeginRead()
+ seq.BeginWrite()
+ if seq.ReadOk(epoch) {
+ t.Errorf("ReadOk: got true, wanted false")
+ }
+ seq.EndWrite()
+}
+
+func BenchmarkSeqCountWriteUncontended(b *testing.B) {
+ var seq SeqCount
+ for i := 0; i < b.N; i++ {
+ seq.BeginWrite()
+ seq.EndWrite()
+ }
+}
+
+func BenchmarkSeqCountReadUncontended(b *testing.B) {
+ var seq SeqCount
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ epoch := seq.BeginRead()
+ if !seq.ReadOk(epoch) {
+ b.Fatalf("ReadOk: got false, wanted true")
+ }
+ }
+ })
+}
+
+func TestPointersInType(t *testing.T) {
+ for _, test := range []struct {
+ name string // used for both test and value name
+ val interface{}
+ ptrs []string
+ }{
+ {
+ name: "EmptyStruct",
+ val: struct{}{},
+ },
+ {
+ name: "Int",
+ val: int(0),
+ },
+ {
+ name: "MixedStruct",
+ val: struct {
+ b bool
+ I int
+ ExportedPtr *struct{}
+ unexportedPtr *struct{}
+ arr [2]int
+ ptrArr [2]*int
+ nestedStruct struct {
+ nestedNonptr int
+ nestedPtr *int
+ }
+ structArr [1]struct {
+ nonptr int
+ ptr *int
+ }
+ }{},
+ ptrs: []string{
+ "MixedStruct.ExportedPtr",
+ "MixedStruct.unexportedPtr",
+ "MixedStruct.ptrArr[]",
+ "MixedStruct.nestedStruct.nestedPtr",
+ "MixedStruct.structArr[].ptr",
+ },
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ typ := reflect.TypeOf(test.val)
+ ptrs := PointersInType(typ, test.name)
+ t.Logf("Found pointers: %v", ptrs)
+ if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) {
+ t.Errorf("Got %v, wanted %v", ptrs, test.ptrs)
+ }
+ })
+ }
+}
diff --git a/pkg/syncutil/syncutil.go b/pkg/syncutil/syncutil.go
new file mode 100644
index 000000000..66e750d06
--- /dev/null
+++ b/pkg/syncutil/syncutil.go
@@ -0,0 +1,7 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package syncutil provides synchronization primitives.
+package syncutil
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 3c2b2b5ea..65d4d0cd8 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -6,6 +6,8 @@ package(licenses = ["notice"])
go_library(
name = "tcpip",
srcs = [
+ "packet_buffer.go",
+ "packet_buffer_state.go",
"tcpip.go",
"time_unsafe.go",
],
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 8ced960bb..ee077ae83 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -151,10 +151,8 @@ func TestCloseReader(t *testing.T) {
buf := make([]byte, 256)
n, err := c.Read(buf)
- got, ok := err.(*net.OpError)
- want := tcpip.ErrConnectionAborted
- if n != 0 || !ok || got.Err.Error() != want.String() {
- t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want)
+ if n != 0 || err != io.EOF {
+ t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err)
}
}()
sender, err := connect(s, addr)
@@ -203,10 +201,8 @@ func TestCloseReaderWithForwarder(t *testing.T) {
buf := make([]byte, 256)
n, e := c.Read(buf)
- got, ok := e.(*net.OpError)
- want := tcpip.ErrConnectionAborted
- if n != 0 || !ok || got.Err.Error() != want.String() {
- t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want)
+ if n != 0 || e != io.EOF {
+ t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e)
}
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go
index 4287464f3..ba21f4eca 100644
--- a/pkg/tcpip/buffer/prependable.go
+++ b/pkg/tcpip/buffer/prependable.go
@@ -41,6 +41,11 @@ func NewPrependableFromView(v View) Prependable {
return Prependable{buf: v, usedIdx: 0}
}
+// NewEmptyPrependableFromView creates a new prependable buffer from a View.
+func NewEmptyPrependableFromView(v View) Prependable {
+ return Prependable{buf: v, usedIdx: len(v)}
+}
+
// View returns a View of the backing buffer that contains all prepended
// data so far.
func (p Prependable) View() View {
@@ -72,3 +77,9 @@ func (p *Prependable) Prepend(size int) []byte {
p.usedIdx -= size
return p.View()[:size:size]
}
+
+// DeepCopy copies p and the bytes backing it.
+func (p Prependable) DeepCopy() Prependable {
+ p.buf = append(View(nil), p.buf...)
+ return p
+}
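The two constructors differ only in where usedIdx starts: NewPrependableFromView treats the entire view as data that has already been prepended, while the new NewEmptyPrependableFromView treats it as free headroom. A small sketch of the new functions (illustrative, not part of this change):

	v := make(buffer.View, 8)
	p := buffer.NewEmptyPrependableFromView(v) // usedIdx == 8: all 8 bytes are headroom
	hdr := p.Prepend(2)                        // hdr aliases v[6:8]
	hdr[0], hdr[1] = 0xab, 0xcd
	_ = p.View()      // v[6:8]: only the prepended bytes
	q := p.DeepCopy() // q's backing bytes no longer alias v
	_ = q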
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD
index 4cecfb989..b6fa6fc37 100644
--- a/pkg/tcpip/checker/BUILD
+++ b/pkg/tcpip/checker/BUILD
@@ -10,6 +10,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
],
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 02137e1c9..2f15bf1f1 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -22,6 +22,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
@@ -639,6 +640,8 @@ func ICMPv4Code(want byte) TransportChecker {
// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and
// potentially additional ICMPv6 header fields.
+//
+// ICMPv6 will validate the checksum field before calling checkers.
func ICMPv6(checkers ...TransportChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
@@ -650,6 +653,10 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker {
}
icmp := header.ICMPv6(last.Payload())
+ if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want {
+ t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want)
+ }
+
for _, f := range checkers {
f(t, icmp)
}
diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD
index 0c5c20cea..e648efa71 100644
--- a/pkg/tcpip/hash/jenkins/BUILD
+++ b/pkg/tcpip/hash/jenkins/BUILD
@@ -7,9 +7,7 @@ go_library(
name = "jenkins",
srcs = ["jenkins.go"],
importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
)
go_test(
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 07d09abed..f1d837196 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -19,6 +19,7 @@ go_library(
"ndp_neighbor_advert.go",
"ndp_neighbor_solicit.go",
"ndp_options.go",
+ "ndp_router_advert.go",
"tcp.go",
"udp.go",
],
@@ -36,16 +37,29 @@ go_test(
name = "header_x_test",
size = "small",
srcs = [
+ "checksum_test.go",
+ "ipv6_test.go",
"ipversion_test.go",
"tcp_test.go",
],
- deps = [":header"],
+ deps = [
+ ":header",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
)
go_test(
name = "header_test",
size = "small",
- srcs = ["ndp_test.go"],
+ srcs = [
+ "eth_test.go",
+ "ndp_test.go",
+ ],
embed = [":header"],
- deps = ["//pkg/tcpip"],
+ deps = [
+ "//pkg/tcpip",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ ],
)
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 39a4d69be..9749c7f4d 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -23,11 +23,17 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
-func calculateChecksum(buf []byte, initial uint32) uint16 {
+func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
v := initial
+ if odd {
+ v += uint32(buf[0])
+ buf = buf[1:]
+ }
+
l := len(buf)
- if l&1 != 0 {
+ odd = l&1 != 0
+ if odd {
l--
v += uint32(buf[l]) << 8
}
@@ -36,7 +42,7 @@ func calculateChecksum(buf []byte, initial uint32) uint16 {
v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
}
- return ChecksumCombine(uint16(v), uint16(v>>16))
+ return ChecksumCombine(uint16(v), uint16(v>>16)), odd
}
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
@@ -44,7 +50,8 @@ func calculateChecksum(buf []byte, initial uint32) uint16 {
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
- return calculateChecksum(buf, uint32(initial))
+ s, _ := calculateChecksum(buf, false, uint32(initial))
+ return s
}
// ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in
@@ -52,19 +59,40 @@ func Checksum(buf []byte, initial uint16) uint16 {
//
// The initial checksum must have been computed on an even number of bytes.
func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 {
- var odd bool
+ return ChecksumVVWithOffset(vv, initial, 0, vv.Size())
+}
+
+// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the
+// bytes in the given VectorisedView.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 {
+ odd := false
sum := initial
for _, v := range vv.Views() {
if len(v) == 0 {
continue
}
- s := uint32(sum)
- if odd {
- s += uint32(v[0])
- v = v[1:]
+
+ if off >= len(v) {
+ off -= len(v)
+ continue
+ }
+ v = v[off:]
+
+ l := len(v)
+ if l > size {
+ l = size
+ }
+ v = v[:l]
+
+ sum, odd = calculateChecksum(v, odd, uint32(sum))
+
+ size -= len(v)
+ if size == 0 {
+ break
}
- odd = len(v)&1 != 0
- sum = calculateChecksum(v, s)
+ off = 0
}
return sum
}
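As a hand check of the folding logic: in the "OneView" case of the new test below, the bytes {1, 9, 0, 5, 4} fold into the 16-bit words 0x0109 and 0x0005, and the trailing odd byte is right-padded to 0x0400; 0x0109 + 0x0005 + 0x0400 = 0x050E = 1294, the expected checksum. The odd flag threaded through calculateChecksum is what preserves this byte alignment when a view boundary falls on an odd offset.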
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
new file mode 100644
index 000000000..86b466c1c
--- /dev/null
+++ b/pkg/tcpip/header/checksum_test.go
@@ -0,0 +1,109 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package header_test tests the encoding and decoding of network protocol
+// headers.
+package header_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+func TestChecksumVVWithOffset(t *testing.T) {
+ testCases := []struct {
+ name string
+ vv buffer.VectorisedView
+ off, size int
+ initial uint16
+ want uint16
+ }{
+ {
+ name: "empty",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
+ }),
+ off: 0,
+ size: 0,
+ want: 0,
+ },
+ {
+ name: "OneView",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
+ }),
+ off: 0,
+ size: 5,
+ want: 1294,
+ },
+ {
+ name: "TwoViews",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
+ }),
+ off: 0,
+ size: 11,
+ want: 33819,
+ },
+ {
+ name: "TwoViewsWithOffset",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
+ }),
+ off: 1,
+ size: 11,
+ want: 33819,
+ },
+ {
+ name: "ThreeViewsWithOffset",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
+ }),
+ off: 7,
+ size: 11,
+ want: 33819,
+ },
+ {
+ name: "ThreeViewsWithInitial",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}),
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}),
+ }),
+ initial: 77,
+ off: 7,
+ size: 11,
+ want: 33896,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want {
+ t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want)
+ }
+ v := tc.vv.ToView()
+ v.TrimFront(tc.off)
+ v.CapLength(tc.size)
+ if got, want := header.Checksum(v, tc.initial), tc.want; got != want {
+ t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index 4c3d3311f..f5d2c127f 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -48,8 +48,48 @@ const (
// EthernetAddressSize is the size, in bytes, of an ethernet address.
EthernetAddressSize = 6
+
+ // unspecifiedEthernetAddress is the unspecified ethernet address
+ // (all bits set to 0).
+ unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
+
+ // unicastMulticastFlagMask is the mask of the least significant bit in
+ // the first octet (in network byte order) of an ethernet address that
+ // determines whether the ethernet address is a unicast or multicast. If
+ // the masked bit is a 1, then the address is a multicast, unicast
+ // otherwise.
+ //
+ // See the IEEE Std 802-2001 document for more details. Specifically,
+ // section 9.2.1 of http://ieee802.org/secmail/pdfocSP2xXA6d.pdf:
+ // "A 48-bit universal address consists of two parts. The first 24 bits
+ // correspond to the OUI as assigned by the IEEE, except that the
+ // assignee may set the LSB of the first octet to 1 for group addresses
+ // or set it to 0 for individual addresses."
+ unicastMulticastFlagMask = 1
+
+ // unicastMulticastFlagByteIdx is the byte that holds the
+ // unicast/multicast flag. See unicastMulticastFlagMask.
+ unicastMulticastFlagByteIdx = 0
+)
+
+const (
+ // EthernetProtocolAll is a catch-all for all protocols carried inside
+ // an ethernet frame. It is mainly used to create packet sockets that
+ // capture all traffic.
+ EthernetProtocolAll tcpip.NetworkProtocolNumber = 0x0003
+
+ // EthernetProtocolPUP is the PARC Universal Packet protocol ethertype.
+ EthernetProtocolPUP tcpip.NetworkProtocolNumber = 0x0200
)
+// Ethertypes holds the protocol numbers describing the payload of an ethernet
+// frame. These types aren't necessarily supported by netstack, but can be used
+// to catch all traffic of a type via packet endpoints.
+var Ethertypes = []tcpip.NetworkProtocolNumber{
+ EthernetProtocolAll,
+ EthernetProtocolPUP,
+}
+
// SourceAddress returns the "MAC source" field of the ethernet frame header.
func (b Ethernet) SourceAddress() tcpip.LinkAddress {
return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize])
@@ -72,3 +112,25 @@ func (b Ethernet) Encode(e *EthernetFields) {
copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr)
copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
}
+
+// IsValidUnicastEthernetAddress returns true if addr is a valid unicast
+// ethernet address.
+func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
+ // Must be of the right length.
+ if len(addr) != EthernetAddressSize {
+ return false
+ }
+
+ // Must not be unspecified.
+ if addr == unspecifiedEthernetAddress {
+ return false
+ }
+
+ // Must not be a multicast.
+ if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 {
+ return false
+ }
+
+ // addr is a valid unicast ethernet address.
+ return true
+}
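As a concrete check of the multicast flag: in the tests below, 01:02:03:04:05:06 has the least significant bit of its first octet set (0x01 & 1 == 1), so it is a multicast address and is rejected, while 02:02:03:04:05:06 (0x02 & 1 == 0) is accepted as a valid unicast address.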
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
new file mode 100644
index 000000000..6634c90f5
--- /dev/null
+++ b/pkg/tcpip/header/eth_test.go
@@ -0,0 +1,68 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestIsValidUnicastEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.LinkAddress
+ expected bool
+ }{
+ {
+ "Nil",
+ tcpip.LinkAddress([]byte(nil)),
+ false,
+ },
+ {
+ "Empty",
+ tcpip.LinkAddress(""),
+ false,
+ },
+ {
+ "InvalidLength",
+ tcpip.LinkAddress("\x01\x02\x03"),
+ false,
+ },
+ {
+ "Unspecified",
+ unspecifiedEthernetAddress,
+ false,
+ },
+ {
+ "Multicast",
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ false,
+ },
+ {
+ "Valid",
+ tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"),
+ true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected {
+ t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index e51c5098c..b4037b6c8 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -80,6 +80,13 @@ const (
// icmpv6SequenceOffset is the offset of the sequence field
// in a ICMPv6 Echo Request/Reply message.
icmpv6SequenceOffset = 6
+
+ // NDPHopLimit is the expected IP hop limit value of 255 for received
+ // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
+ // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
+ // drop the NDP packet. All outgoing NDP packets must use this value for
+ // their IP hop limit field.
+ NDPHopLimit = 255
)
// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
@@ -125,7 +132,7 @@ func (b ICMPv6) Checksum() uint16 {
return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:])
}
-// SetChecksum calculates and sets the ICMP checksum field.
+// SetChecksum sets the ICMP checksum field.
func (b ICMPv6) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum)
}
@@ -190,7 +197,7 @@ func (b ICMPv6) Payload() []byte {
return b[ICMPv6PayloadOffset:]
}
-// ICMPv6Checksum calculates the ICMP checksum over the provided ICMP header,
+// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header,
// IPv6 src/dst addresses and the payload.
func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
// Calculate the IPv6 pseudo-header upper-layer checksum.
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index b125bbea5..fc671e439 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -90,9 +90,23 @@ const (
// IPv6Any is the non-routable IPv6 "any" meta address. It is also
// known as the unspecified address.
IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+
+ // IIDSize is the size of an interface identifier (IID), in bytes, as
+ // defined by RFC 4291 section 2.5.1.
+ IIDSize = 8
+
+ // IIDOffsetInIPv6Address is the offset, in bytes, from the start
+ // of an IPv6 address to the beginning of the interface identifier
+ // (IID) for auto-generated addresses. That is, all bytes before
+ // the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all
+ // bytes including and after the IIDOffsetInIPv6Address-th byte are
+ // for the IID.
+ IIDOffsetInIPv6Address = 8
)
-// IPv6EmptySubnet is the empty IPv6 subnet.
+// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
+// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to
+// be contained within this subnet.
var IPv6EmptySubnet = func() tcpip.Subnet {
subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any))
if err != nil {
@@ -101,6 +115,15 @@ var IPv6EmptySubnet = func() tcpip.Subnet {
return subnet
}()
+// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined
+// by RFC 4291 section 2.5.6.
+//
+// The prefix is fe80::/64
+var IPv6LinkLocalPrefix = tcpip.AddressWithPrefix{
+ Address: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 64,
+}
+
// PayloadLength returns the value of the "payload length" field of the ipv6
// header.
func (b IPv6) PayloadLength() uint16 {
@@ -255,27 +278,43 @@ func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
return solicitedNodeMulticastPrefix + addr[len(addr)-3:]
}
+// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64
+// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1.
+//
+// buf MUST be at least 8 bytes.
+func EthernetAdddressToModifiedEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) {
+ buf[0] = linkAddr[0] ^ 2
+ buf[1] = linkAddr[1]
+ buf[2] = linkAddr[2]
+ buf[3] = 0xFF
+ buf[4] = 0xFE
+ buf[5] = linkAddr[3]
+ buf[6] = linkAddr[4]
+ buf[7] = linkAddr[5]
+}
+
+// EthernetAddressToModifiedEUI64 computes a modified EUI-64 from a 48-bit
+// Ethernet/MAC address, as per RFC 4291 section 2.5.1.
+func EthernetAddressToModifiedEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte {
+ var buf [IIDSize]byte
+ EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:])
+ return buf
+}
+
// LinkLocalAddr computes the default IPv6 link-local address from a link-layer
// (MAC) address.
func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
- // Convert a 48-bit MAC to an EUI-64 and then prepend the link-local
- // header, FE80::.
+ // Convert a 48-bit MAC to a modified EUI-64 and then prepend the
+ // link-local header, FE80::.
//
// The conversion is very nearly:
// aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
// Note the capital A. The conversion aa->Aa involves a bit flip.
- lladdrb := [16]byte{
- 0: 0xFE,
- 1: 0x80,
- 8: linkAddr[0] ^ 2,
- 9: linkAddr[1],
- 10: linkAddr[2],
- 11: 0xFF,
- 12: 0xFE,
- 13: linkAddr[3],
- 14: linkAddr[4],
- 15: linkAddr[5],
+ lladdrb := [IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
}
+ EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:])
return tcpip.Address(lladdrb[:])
}
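Worked through on the address used in the new tests, 02:02:03:04:05:06: flipping the universal/local bit turns the first octet 0x02 into 0x00, and 0xFF, 0xFE is inserted after the third octet, giving the IID 00:02:03:ff:fe:04:05:06; prepending the fe80:: prefix yields the link-local address fe80::2:3ff:fe04:506.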
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
new file mode 100644
index 000000000..42c5c6fc1
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -0,0 +1,45 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+
+func TestEthernetAdddressToModifiedEUI64(t *testing.T) {
+ expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6}
+
+ if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" {
+ t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff)
+ }
+
+ var buf [header.IIDSize]byte
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:])
+ if diff := cmp.Diff(expectedIID, buf); diff != "" {
+ t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff)
+ }
+}
+
+func TestLinkLocalAddr(t *testing.T) {
+ if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want {
+ t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want)
+ }
+}
diff --git a/pkg/tcpip/header/ndp_neighbor_advert.go b/pkg/tcpip/header/ndp_neighbor_advert.go
index 5c2b472c8..505c92668 100644
--- a/pkg/tcpip/header/ndp_neighbor_advert.go
+++ b/pkg/tcpip/header/ndp_neighbor_advert.go
@@ -18,6 +18,8 @@ import "gvisor.dev/gvisor/pkg/tcpip"
// NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will
// only contain the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.4 for more details.
type NDPNeighborAdvert []byte
const (
diff --git a/pkg/tcpip/header/ndp_neighbor_solicit.go b/pkg/tcpip/header/ndp_neighbor_solicit.go
index 1dcb0fbc6..3a1b8e139 100644
--- a/pkg/tcpip/header/ndp_neighbor_solicit.go
+++ b/pkg/tcpip/header/ndp_neighbor_solicit.go
@@ -18,6 +18,8 @@ import "gvisor.dev/gvisor/pkg/tcpip"
// NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only
// contain the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.3 for more details.
type NDPNeighborSolicit []byte
const (
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index b28bde15b..06e0bace2 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -15,6 +15,11 @@
package header
import (
+ "encoding/binary"
+ "errors"
+ "math"
+ "time"
+
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -27,15 +32,226 @@ const (
// Link Layer Option for an Ethernet address.
ndpTargetEthernetLinkLayerAddressSize = 8
+ // NDPPrefixInformationType is the type of the Prefix Information
+ // option, as per RFC 4861 section 4.6.2.
+ NDPPrefixInformationType = 3
+
+ // ndpPrefixInformationLength is the expected length, in bytes, of the
+ // body of an NDP Prefix Information option, as per RFC 4861 section
+ // 4.6.2 which specifies that the Length field is 4. Given this, the
+ // expected length, in bytes, is 30 because 4 * lengthByteUnits (8) - 2
+ // (Type & Length) = 30.
+ ndpPrefixInformationLength = 30
+
+ // ndpPrefixInformationPrefixLengthOffset is the offset of the Prefix
+ // Length field within an NDPPrefixInformation.
+ ndpPrefixInformationPrefixLengthOffset = 0
+
+ // ndpPrefixInformationFlagsOffset is the offset of the flags byte
+ // within an NDPPrefixInformation.
+ ndpPrefixInformationFlagsOffset = 1
+
+ // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag
+ // field in the flags byte within an NDPPrefixInformation.
+ ndpPrefixInformationOnLinkFlagMask = (1 << 7)
+
+ // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the
+ // Autonomous Address-Configuration flag field in the flags byte within
+ // an NDPPrefixInformation.
+ ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6)
+
+ // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1
+ // field in the flags byte within an NDPPrefixInformation.
+ ndpPrefixInformationReserved1FlagsMask = 63
+
+ // ndpPrefixInformationValidLifetimeOffset is the start of the 4-byte
+ // Valid Lifetime field within an NDPPrefixInformation.
+ ndpPrefixInformationValidLifetimeOffset = 2
+
+ // ndpPrefixInformationPreferredLifetimeOffset is the start of the
+ // 4-byte Preferred Lifetime field within an NDPPrefixInformation.
+ ndpPrefixInformationPreferredLifetimeOffset = 6
+
+ // ndpPrefixInformationReserved2Offset is the start of the 4-byte
+ // Reserved2 field within an NDPPrefixInformation.
+ ndpPrefixInformationReserved2Offset = 10
+
+ // ndpPrefixInformationReserved2Length is the length of the Reserved2
+ // field.
+ //
+ // It is 4 bytes.
+ ndpPrefixInformationReserved2Length = 4
+
+ // ndpPrefixInformationPrefixOffset is the start of the Prefix field
+ // within an NDPPrefixInformation.
+ ndpPrefixInformationPrefixOffset = 14
+
+ // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS
+ // Server option, as per RFC 8106 section 5.1.
+ NDPRecursiveDNSServerOptionType = 25
+
+ // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte
+ // Lifetime field within an NDPRecursiveDNSServer.
+ ndpRecursiveDNSServerLifetimeOffset = 2
+
+ // ndpRecursiveDNSServerAddressesOffset is the start of the addresses
+ // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer.
+ ndpRecursiveDNSServerAddressesOffset = 6
+
+ // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS
+ // Server option's length field value when it contains at least one
+ // IPv6 address.
+ minNDPRecursiveDNSServerLength = 3
+
// lengthByteUnits is the multiplier factor for the Length field of an
// NDP option. That is, the length field for NDP options is in units of
// 8 octets, as per RFC 4861 section 4.6.
lengthByteUnits = 8
)
+var (
+ // NDPInfiniteLifetime is a value that represents infinity for the
+ // 4-byte lifetime fields found in various NDP options. Its value is
+ // (2^32 - 1)s = 4294967295s.
+ //
+ // This is a variable instead of a constant so that tests can change
+ // this value to a smaller value. It should only be modified by tests.
+ NDPInfiniteLifetime = time.Second * math.MaxUint32
+)
+
+// NDPOptionIterator is an iterator of NDPOption.
+//
+// Note, between when an NDPOptionIterator is obtained and last used, no changes
+// to the NDPOptions may happen. Doing so may cause undefined and unexpected
+// behaviour. It is fine to obtain an NDPOptionIterator, iterate over the first
+// few NDPOption then modify the backing NDPOptions so long as the
+// NDPOptionIterator obtained before modification is no longer used.
+type NDPOptionIterator struct {
+ // The NDPOptions this NDPOptionIterator is iterating over.
+ opts NDPOptions
+}
+
+// Potential errors when iterating over an NDPOptions.
+var (
+ ErrNDPOptBufExhausted = errors.New("Buffer unexpectedly exhausted")
+ ErrNDPOptZeroLength = errors.New("NDP option has zero-valued Length field")
+ ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body")
+ ErrNDPInvalidLength = errors.New("NDP option's Length value is invalid as per relevant RFC")
+)
+
+// Next returns the next element in the backing NDPOptions, whether iteration
+// is done, and an error if one occurred.
+//
+// The return can be read as option, done, error. Note, option should only be
+// used if done is false and error is nil.
+func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
+ for {
+ // Do we still have elements to look at?
+ if len(i.opts) == 0 {
+ return nil, true, nil
+ }
+
+ // Do we have enough bytes for an NDP option that has a Length
+ // field of at least 1? Note, 0 in the Length field is invalid.
+ if len(i.opts) < lengthByteUnits {
+ return nil, true, ErrNDPOptBufExhausted
+ }
+
+ // Get the Type field.
+ t := i.opts[0]
+
+ // Get the Length field.
+ l := i.opts[1]
+
+ // This would indicate an erroneous NDP option as the Length
+ // field should never be 0.
+ if l == 0 {
+ return nil, true, ErrNDPOptZeroLength
+ }
+
+ // How many bytes are in the option body?
+ numBytes := int(l) * lengthByteUnits
+ numBodyBytes := numBytes - 2
+
+ potentialBody := i.opts[2:]
+
+ // This would indicate an erroneous NDPOptions buffer as we ran
+ // out of the buffer in the middle of an NDP option.
+ if left := len(potentialBody); left < numBodyBytes {
+ return nil, true, ErrNDPOptBufExhausted
+ }
+
+ // Get only the options body, leaving the rest of the options
+ // buffer alone.
+ body := potentialBody[:numBodyBytes]
+
+ // Update opts with the remaining options body.
+ i.opts = i.opts[numBytes:]
+
+ switch t {
+ case NDPTargetLinkLayerAddressOptionType:
+ return NDPTargetLinkLayerAddressOption(body), false, nil
+
+ case NDPPrefixInformationType:
+ // Make sure the length of a Prefix Information option
+ // body is ndpPrefixInformationLength, as per RFC 4861
+ // section 4.6.2.
+ if numBodyBytes != ndpPrefixInformationLength {
+ return nil, true, ErrNDPOptMalformedBody
+ }
+
+ return NDPPrefixInformation(body), false, nil
+
+ case NDPRecursiveDNSServerOptionType:
+ // RFC 8106 section 5.3.1 outlines that the RDNSS option
+ // must have a minimum length of 3 so it contains at
+ // least one IPv6 address.
+ if l < minNDPRecursiveDNSServerLength {
+ return nil, true, ErrNDPInvalidLength
+ }
+
+ opt := NDPRecursiveDNSServer(body)
+ if len(opt.Addresses()) == 0 {
+ return nil, true, ErrNDPOptMalformedBody
+ }
+
+ return opt, false, nil
+
+ default:
+ // We do not yet recognize the option, just skip for
+ // now. This is okay because RFC 4861 allows us to
+ // skip/ignore any unrecognized options. However,
+ // we MUST recognize all the options in RFC 4861.
+ //
+ // TODO(b/141487990): Handle all NDP options as defined
+ // by RFC 4861.
+ }
+ }
+}
+
// NDPOptions is a buffer of NDP options as defined by RFC 4861 section 4.6.
type NDPOptions []byte
+// Iter returns an iterator of NDPOption.
+//
+// If check is true, Iter will do an integrity check on the options by iterating
+// over it and returning an error if detected.
+//
+// See NDPOptionIterator for more information.
+func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) {
+ it := NDPOptionIterator{opts: b}
+
+ if check {
+ for it2 := it; true; {
+ if _, done, err := it2.Next(); err != nil || done {
+ return it, err
+ }
+ }
+ }
+
+ return it, nil
+}
+
+// Serialize serializes the provided list of NDP options into b.
//
// Note, b must be of sufficient size to hold all the options in s. See
@@ -75,15 +291,15 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
return done
}
-// ndpOption is the set of functions to be implemented by all NDP option types.
-type ndpOption interface {
- // Type returns the type of this ndpOption.
+// NDPOption is the set of functions to be implemented by all NDP option types.
+type NDPOption interface {
+ // Type returns the type of the receiver.
Type() uint8
- // Length returns the length of the body of this ndpOption, in bytes.
+ // Length returns the length of the body of the receiver, in bytes.
Length() int
- // serializeInto serializes this ndpOption into the provided byte
+ // serializeInto serializes the receiver into the provided byte
// buffer.
//
// Note, the caller MUST provide a byte buffer with size of at least
@@ -92,15 +308,15 @@ type ndpOption interface {
// buffer is not of sufficient size.
//
// serializeInto will return the number of bytes that was used to
- // serialize this ndpOption. Implementers must only use the number of
- // bytes required to serialize this ndpOption. Callers MAY provide a
+ // serialize the receiver. Implementers must only use the number of
+ // bytes required to serialize the receiver. Callers MAY provide a
// larger buffer than required to serialize into.
serializeInto([]byte) int
}
// paddedLength returns the length of o, in bytes, with any padding bytes, if
// required.
-func paddedLength(o ndpOption) int {
+func paddedLength(o NDPOption) int {
l := o.Length()
if l == 0 {
@@ -139,7 +355,7 @@ func paddedLength(o ndpOption) int {
}
// NDPOptionsSerializer is a serializer for NDP options.
-type NDPOptionsSerializer []ndpOption
+type NDPOptionsSerializer []NDPOption
// Length returns the total number of bytes required to serialize.
func (b NDPOptionsSerializer) Length() int {
@@ -154,19 +370,220 @@ func (b NDPOptionsSerializer) Length() int {
// NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option
// as defined by RFC 4861 section 4.6.1.
+//
+// It is the first X bytes following the NDP option's Type and Length field
+// where X is the value of the Length field multiplied by lengthByteUnits,
+// minus the 2 bytes for the Type and Length fields.
type NDPTargetLinkLayerAddressOption tcpip.LinkAddress
-// Type implements ndpOption.Type.
+// Type implements NDPOption.Type.
func (o NDPTargetLinkLayerAddressOption) Type() uint8 {
return NDPTargetLinkLayerAddressOptionType
}
-// Length implements ndpOption.Length.
+// Length implements NDPOption.Length.
func (o NDPTargetLinkLayerAddressOption) Length() int {
return len(o)
}
-// serializeInto implements ndpOption.serializeInto.
+// serializeInto implements NDPOption.serializeInto.
func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int {
return copy(b, o)
}
+
+// EthernetAddress will return an ethernet (MAC) address if the
+// NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize
+// bytes. If the body has more than EthernetAddressSize bytes, only the first
+// EthernetAddressSize bytes are returned as that is all that is needed for an
+// Ethernet address.
+func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
+ if len(o) >= EthernetAddressSize {
+ return tcpip.LinkAddress(o[:EthernetAddressSize])
+ }
+
+ return tcpip.LinkAddress([]byte(nil))
+}
+
+// NDPPrefixInformation is the NDP Prefix Information option as defined by
+// RFC 4861 section 4.6.2.
+//
+// The length, in bytes, of a valid NDP Prefix Information option body MUST be
+// ndpPrefixInformationLength bytes.
+type NDPPrefixInformation []byte
+
+// Type implements NDPOption.Type.
+func (o NDPPrefixInformation) Type() uint8 {
+ return NDPPrefixInformationType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPPrefixInformation) Length() int {
+ return ndpPrefixInformationLength
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPPrefixInformation) serializeInto(b []byte) int {
+ used := copy(b, o)
+
+ // Zero out the Reserved1 field.
+ b[ndpPrefixInformationFlagsOffset] &^= ndpPrefixInformationReserved1FlagsMask
+
+ // Zero out the Reserved2 field.
+ reserved2 := b[ndpPrefixInformationReserved2Offset:][:ndpPrefixInformationReserved2Length]
+ for i := range reserved2 {
+ reserved2[i] = 0
+ }
+
+ return used
+}
+
+// PrefixLength returns the number of leading bits in the Prefix that are
+// valid.
+//
+// Valid values are in the range [0, 128], but o may not always contain valid
+// values. It is up to the caller to validate the Prefix Information option.
+func (o NDPPrefixInformation) PrefixLength() uint8 {
+ return o[ndpPrefixInformationPrefixLengthOffset]
+}
+
+// OnLinkFlag returns true if the prefix is considered on-link. On-link means
+// that a forwarding node is not needed to send packets to other nodes on the
+// same prefix.
+//
+// Note, when this function returns false, no statement is made about the
+// on-link property of a prefix. That is, if OnLinkFlag returns false, the
+// caller MUST NOT conclude that the prefix is off-link and MUST NOT update any
+// previously stored state for this prefix about its on-link status.
+func (o NDPPrefixInformation) OnLinkFlag() bool {
+ return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationOnLinkFlagMask != 0
+}
+
+// AutonomousAddressConfigurationFlag returns true if the prefix can be used for
+// Stateless Address Auto-Configuration (as specified in RFC 4862).
+func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool {
+ return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationAutoAddrConfFlagMask != 0
+}
+
+// ValidLifetime returns the length of time that the prefix is valid for the
+// purpose of on-link determination. This value is relative to the send time of
+// the packet that the Prefix Information option was present in.
+//
+// Note, a value of 0 implies the prefix should not be considered as on-link,
+// and a value of infinity/forever is represented by
+// NDPInfiniteLifetime.
+func (o NDPPrefixInformation) ValidLifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 4861 section 4.6.2.
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:]))
+}
+
+// PreferredLifetime returns the length of time that an address generated from
+// the prefix via Stateless Address Auto-Configuration remains preferred. This
+// value is relative to the send time of the packet that the Prefix Information
+// option was present in.
+//
+// Note, a value of 0 implies that addresses generated from the prefix should
+// no longer remain preferred, and a value of infinity is represented by
+// NDPInfiniteLifetime.
+//
+// Also note that the value of this field MUST NOT exceed the Valid Lifetime
+// field to avoid preferring addresses that are no longer valid, for the
+// purpose of Stateless Address Auto-Configuration.
+func (o NDPPrefixInformation) PreferredLifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 4861 section 4.6.2.
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationPreferredLifetimeOffset:]))
+}
+
+// Prefix returns an IPv6 address or a prefix of an IPv6 address. The Prefix
+// Length field (see NDPPrefixInformation.PrefixLength) contains the number
+// of valid leading bits in the prefix.
+//
+// Hosts SHOULD ignore an NDP Prefix Information option where the Prefix field
+// holds the link-local prefix (fe80::).
+func (o NDPPrefixInformation) Prefix() tcpip.Address {
+ return tcpip.Address(o[ndpPrefixInformationPrefixOffset:][:IPv6AddressSize])
+}
+
+// Subnet returns the Prefix field and Prefix Length field represented in a
+// tcpip.Subnet.
+func (o NDPPrefixInformation) Subnet() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: o.Prefix(),
+ PrefixLen: int(o.PrefixLength()),
+ }
+ return addrWithPrefix.Subnet()
+}
+
+// NDPRecursiveDNSServer is the NDP Recursive DNS Server option, as defined by
+// RFC 8106 section 5.1.
+//
+// To make sure that the option meets its minimum length and does not end in the
+// middle of a DNS server's IPv6 address, the length of a valid
+// NDPRecursiveDNSServer must meet the following constraint:
+// (Length - ndpRecursiveDNSServerAddressesOffset) % IPv6AddressSize == 0
+type NDPRecursiveDNSServer []byte
+
+// Type returns the type of an NDP Recursive DNS Server option.
+//
+// Type implements NDPOption.Type.
+func (NDPRecursiveDNSServer) Type() uint8 {
+ return NDPRecursiveDNSServerOptionType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPRecursiveDNSServer) Length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPRecursiveDNSServer) serializeInto(b []byte) int {
+ used := copy(b, o)
+
+ // Zero out the reserved bytes that are before the Lifetime field.
+ for i := 0; i < ndpRecursiveDNSServerLifetimeOffset; i++ {
+ b[i] = 0
+ }
+
+ return used
+}
+
+// Lifetime returns the length of time that the DNS server addresses
+// in this option may be used for name resolution.
+//
+// Note, a value of 0 implies the addresses should no longer be used,
+// and a value of infinity/forever is represented by NDPInfiniteLifetime.
+//
+// Lifetime may panic if o does not have enough bytes to hold the Lifetime
+// field.
+func (o NDPRecursiveDNSServer) Lifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 8106 section 5.1.
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRecursiveDNSServerLifetimeOffset:]))
+}
+
+// Addresses returns the recursive DNS server IPv6 addresses that may be
+// used for name resolution.
+//
+// Note, some of the addresses returned MAY be link-local addresses.
+//
+// Addresses may panic if o does not hold valid IPv6 addresses.
+func (o NDPRecursiveDNSServer) Addresses() []tcpip.Address {
+ l := len(o)
+ if l < ndpRecursiveDNSServerAddressesOffset {
+ return nil
+ }
+
+ l -= ndpRecursiveDNSServerAddressesOffset
+ if l%IPv6AddressSize != 0 {
+ return nil
+ }
+
+ buf := o[ndpRecursiveDNSServerAddressesOffset:]
+ var addrs []tcpip.Address
+ for len(buf) > 0 {
+ addr := tcpip.Address(buf[:IPv6AddressSize])
+ if !IsV6UnicastAddress(addr) {
+ return nil
+ }
+ addrs = append(addrs, addr)
+ buf = buf[IPv6AddressSize:]
+ }
+ return addrs
+}
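A minimal consumption sketch for the iterator's option/done/error protocol (illustrative, not part of this change); because Iter(true) validates the entire buffer up front, a later pass over the same options cannot fail:

	it, err := opts.Iter(true) // integrity-check the whole buffer first
	if err != nil {
		return err
	}
	for {
		opt, done, err := it.Next()
		if err != nil || done {
			break
		}
		switch opt := opt.(type) {
		case header.NDPPrefixInformation:
			_ = opt.Subnet()
		case header.NDPTargetLinkLayerAddressOption:
			_ = opt.EthernetAddress()
		case header.NDPRecursiveDNSServer:
			_ = opt.Addresses()
		}
	}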
diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go
new file mode 100644
index 000000000..bf7610863
--- /dev/null
+++ b/pkg/tcpip/header/ndp_router_advert.go
@@ -0,0 +1,112 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "time"
+)
+
+// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain
+// the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.2 for more details.
+type NDPRouterAdvert []byte
+
+const (
+ // NDPRAMinimumSize is the minimum size of a valid NDP Router
+ // Advertisement message (body of an ICMPv6 packet).
+ NDPRAMinimumSize = 12
+
+ // ndpRACurrHopLimitOffset is the byte of the Curr Hop Limit field
+ // within an NDPRouterAdvert.
+ ndpRACurrHopLimitOffset = 0
+
+ // ndpRAFlagsOffset is the byte with the NDP RA bit-fields/flags
+ // within an NDPRouterAdvert.
+ ndpRAFlagsOffset = 1
+
+ // ndpRAManagedAddrConfFlagMask is the mask of the Managed Address
+ // Configuration flag within the bit-field/flags byte of an
+ // NDPRouterAdvert.
+ ndpRAManagedAddrConfFlagMask = (1 << 7)
+
+ // ndpRAOtherConfFlagMask is the mask of the Other Configuration flag
+ // within the bit-field/flags byte of an NDPRouterAdvert.
+ ndpRAOtherConfFlagMask = (1 << 6)
+
+ // ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime
+ // field within an NDPRouterAdvert.
+ ndpRARouterLifetimeOffset = 2
+
+ // ndpRAReachableTimeOffset is the start of the 4-byte Reachable Time
+ // field within an NDPRouterAdvert.
+ ndpRAReachableTimeOffset = 4
+
+ // ndpRARetransTimerOffset is the start of the 4-byte Retrans Timer
+ // field within an NDPRouterAdvert.
+ ndpRARetransTimerOffset = 8
+
+ // ndpRAOptionsOffset is the start of the NDP options in an
+ // NDPRouterAdvert.
+ ndpRAOptionsOffset = 12
+)
+
+// CurrHopLimit returns the value of the Curr Hop Limit field.
+func (b NDPRouterAdvert) CurrHopLimit() uint8 {
+ return b[ndpRACurrHopLimitOffset]
+}
+
+// ManagedAddrConfFlag returns the value of the Managed Address Configuration
+// flag.
+func (b NDPRouterAdvert) ManagedAddrConfFlag() bool {
+ return b[ndpRAFlagsOffset]&ndpRAManagedAddrConfFlagMask != 0
+}
+
+// OtherConfFlag returns the value of the Other Configuration flag.
+func (b NDPRouterAdvert) OtherConfFlag() bool {
+ return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0
+}
+
+// RouterLifetime returns the lifetime associated with the default router. A
+// value of 0 means the source of the Router Advertisement is not a default
+// router and SHOULD NOT appear on the default router list. Note, a value of 0
+// only means that the router should not be used as a default router, it does
+// not apply to other information contained in the Router Advertisement.
+func (b NDPRouterAdvert) RouterLifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 4861 section 4.2.
+ return time.Second * time.Duration(binary.BigEndian.Uint16(b[ndpRARouterLifetimeOffset:]))
+}
+
+// ReachableTime returns the time that a node assumes a neighbor is reachable
+// after having received a reachability confirmation. A value of 0 means
+// that it is unspecified by the source of the Router Advertisement message.
+func (b NDPRouterAdvert) ReachableTime() time.Duration {
+ // The field is the time in milliseconds, as per RFC 4861 section 4.2.
+ return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRAReachableTimeOffset:]))
+}
+
+// RetransTimer returns the time between retransmitted Neighbor Solicitation
+// messages. A value of 0 means that it is unspecified by the source of the
+// Router Advertisement message.
+func (b NDPRouterAdvert) RetransTimer() time.Duration {
+ // The field is the time in milliseconds, as per RFC 4861 section 4.2.
+ return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRARetransTimerOffset:]))
+}
+
+// Options returns an NDPOptions of the options body.
+func (b NDPRouterAdvert) Options() NDPOptions {
+ return NDPOptions(b[ndpRAOptionsOffset:])
+}
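The accessors decode directly from the fixed byte layout: for the buffer {64, 128, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} used in the new TestNDPRouterAdvert below, Curr Hop Limit is the first byte (64), the flags byte 128 sets only the Managed Address Configuration bit, Router Lifetime is the big-endian 16-bit value 0x0102 = 258 seconds, Reachable Time is 0x03040506 = 50595078 milliseconds, and Retrans Timer is 0x0708090A = 117967114 milliseconds.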
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index a431a6e61..2c439d70c 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -17,7 +17,9 @@ package header
import (
"bytes"
"testing"
+ "time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -35,18 +37,18 @@ func TestNDPNeighborSolicit(t *testing.T) {
ns := NDPNeighborSolicit(b)
addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
if got := ns.TargetAddress(); got != addr {
- t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr)
+ t.Errorf("got ns.TargetAddress = %s, want %s", got, addr)
}
// Test updating the Target Address.
addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
ns.SetTargetAddress(addr2)
if got := ns.TargetAddress(); got != addr2 {
- t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr2)
+ t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2)
}
// Make sure the address got updated in the backing buffer.
if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 {
- t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2)
+ t.Errorf("got targetaddress buffer = %s, want %s", got, addr2)
}
}
@@ -64,59 +66,129 @@ func TestNDPNeighborAdvert(t *testing.T) {
na := NDPNeighborAdvert(b)
addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
if got := na.TargetAddress(); got != addr {
- t.Fatalf("got TargetAddress = %s, want %s", got, addr)
+ t.Errorf("got TargetAddress = %s, want %s", got, addr)
}
// Test getting the Router Flag.
if got := na.RouterFlag(); !got {
- t.Fatalf("got RouterFlag = false, want = true")
+ t.Errorf("got RouterFlag = false, want = true")
}
// Test getting the Solicited Flag.
if got := na.SolicitedFlag(); got {
- t.Fatalf("got SolicitedFlag = true, want = false")
+ t.Errorf("got SolicitedFlag = true, want = false")
}
// Test getting the Override Flag.
if got := na.OverrideFlag(); !got {
- t.Fatalf("got OverrideFlag = false, want = true")
+ t.Errorf("got OverrideFlag = false, want = true")
}
// Test updating the Target Address.
addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
na.SetTargetAddress(addr2)
if got := na.TargetAddress(); got != addr2 {
- t.Fatalf("got TargetAddress = %s, want %s", got, addr2)
+ t.Errorf("got TargetAddress = %s, want %s", got, addr2)
}
// Make sure the address got updated in the backing buffer.
if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 {
- t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2)
+ t.Errorf("got targetaddress buffer = %s, want %s", got, addr2)
}
// Test updating the Router Flag.
na.SetRouterFlag(false)
if got := na.RouterFlag(); got {
- t.Fatalf("got RouterFlag = true, want = false")
+ t.Errorf("got RouterFlag = true, want = false")
}
// Test updating the Solicited Flag.
na.SetSolicitedFlag(true)
if got := na.SolicitedFlag(); !got {
- t.Fatalf("got SolicitedFlag = false, want = true")
+ t.Errorf("got SolicitedFlag = false, want = true")
}
// Test updating the Override Flag.
na.SetOverrideFlag(false)
if got := na.OverrideFlag(); got {
- t.Fatalf("got OverrideFlag = true, want = false")
+ t.Errorf("got OverrideFlag = true, want = false")
}
// Make sure flags got updated in the backing buffer.
if got := b[ndpNAFlagsOffset]; got != 64 {
- t.Fatalf("got flags byte = %d, want = 64")
+ t.Errorf("got flags byte = %d, want = 64")
}
}
+func TestNDPRouterAdvert(t *testing.T) {
+ b := []byte{
+ 64, 128, 1, 2,
+ 3, 4, 5, 6,
+ 7, 8, 9, 10,
+ }
+
+ ra := NDPRouterAdvert(b)
+
+ if got := ra.CurrHopLimit(); got != 64 {
+ t.Errorf("got ra.CurrHopLimit = %d, want = 64", got)
+ }
+
+ if got := ra.ManagedAddrConfFlag(); !got {
+ t.Errorf("got ManagedAddrConfFlag = false, want = true")
+ }
+
+ if got := ra.OtherConfFlag(); got {
+ t.Errorf("got OtherConfFlag = true, want = false")
+ }
+
+ if got, want := ra.RouterLifetime(), time.Second*258; got != want {
+ t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want {
+ t.Errorf("got ra.ReachableTime = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want {
+ t.Errorf("got ra.RetransTimer = %d, want = %d", got, want)
+ }
+}
+
+// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the
+// Ethernet address from an NDPTargetLinkLayerAddressOption.
+func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expected tcpip.LinkAddress
+ }{
+ {
+ "ValidMAC",
+ []byte{1, 2, 3, 4, 5, 6},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ {
+ "TLLBodyTooShort",
+ []byte{1, 2, 3, 4, 5},
+ tcpip.LinkAddress([]byte(nil)),
+ },
+ {
+ "TLLBodyLargerThanNeeded",
+ []byte{1, 2, 3, 4, 5, 6, 7, 8},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ tll := NDPTargetLinkLayerAddressOption(test.buf)
+ if got := tll.EthernetAddress(); got != test.expected {
+ t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected)
+ }
+ })
+ }
+}
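+
+// The TLLBodyTooShort case above relies on EthernetAddress returning an
+// empty link address when the option body holds fewer than
+// EthernetAddressSize bytes; a sketch of that guard (an assumption, not
+// necessarily the exact implementation):
+//
+// if len(tll) < EthernetAddressSize {
+// return tcpip.LinkAddress([]byte(nil))
+// }
+// return tcpip.LinkAddress(tll[:EthernetAddressSize])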
+
// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a
// NDPTargetLinkLayerAddressOption.
func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
@@ -159,6 +231,555 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
if !bytes.Equal(test.buf, test.expectedBuf) {
t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
}
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ if len(test.expectedBuf) > 0 {
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
+ t.Fatalf("got Type %= %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
+ }
+ tll := next.(NDPTargetLinkLayerAddressOption)
+ if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
+ t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+
+ if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
+ t.Errorf("got tll.MACAddress = %s, want = %s", got, want)
+ }
+ }
+
+ // Iterator should not return anything else.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+ })
+ }
+}
+
+// TestNDPPrefixInformationOption tests the field getters and serialization
+// of an NDPPrefixInformation.
+func TestNDPPrefixInformationOption(t *testing.T) {
+ b := []byte{
+ 43, 127,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 5, 5, 5, 5,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ }
+
+ targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
+ opts := NDPOptions(targetBuf)
+ serializer := NDPOptionsSerializer{
+ NDPPrefixInformation(b),
+ }
+ opts.Serialize(serializer)
+ expectedBuf := []byte{
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ }
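+ // What serialization is expected to normalize here: byte 1 of the
+ // body (127 = 0b0111_1111) keeps only the L and A flag bits, giving
+ // 64 (A set, L clear), and the 4-byte Reserved2 field (5, 5, 5, 5)
+ // is zeroed; everything else is copied through behind the 2-byte
+ // type/length prefix (3, 4).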
+ if !bytes.Equal(targetBuf, expectedBuf) {
+ t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPPrefixInformationType {
+ t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType)
+ }
+
+ pi := next.(NDPPrefixInformation)
+
+ if got := pi.Type(); got != 3 {
+ t.Errorf("got Type = %d, want = 3", got)
+ }
+
+ if got := pi.Length(); got != 30 {
+ t.Errorf("got Length = %d, want = 30", got)
+ }
+
+ if got := pi.PrefixLength(); got != 43 {
+ t.Errorf("got PrefixLength = %d, want = 43", got)
+ }
+
+ if pi.OnLinkFlag() {
+ t.Error("got OnLinkFlag = true, want = false")
+ }
+
+ if !pi.AutonomousAddressConfigurationFlag() {
+ t.Error("got AutonomousAddressConfigurationFlag = false, want = true")
+ }
+
+ if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want {
+ t.Errorf("got ValidLifetime = %d, want = %d", got, want)
+ }
+
+ if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want {
+ t.Errorf("got PreferredLifetime = %d, want = %d", got, want)
+ }
+
+ if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want {
+ t.Errorf("got Prefix = %s, want = %s", got, want)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+}
+
+func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) {
+ b := []byte{
+ 9, 8,
+ 1, 2, 4, 8,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ }
+ targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
+ expected := []byte{
+ 25, 3, 0, 0,
+ 1, 2, 4, 8,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ }
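+ // As with prefix information serialization above, the type/length
+ // pair is prepended (25, 3, i.e. 3*8 = 24 bytes) and the 2-byte
+ // reserved field (9, 8 above) is zeroed; the lifetime bytes 1, 2, 4, 8
+ // decode to 0x01020408 = 16909320 seconds, which the Lifetime check
+ // below expects.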
+ opts := NDPOptions(targetBuf)
+ serializer := NDPOptionsSerializer{
+ NDPRecursiveDNSServer(b),
+ }
+ if got, want := opts.Serialize(serializer), len(expected); got != want {
+ t.Errorf("got Serialize = %d, want = %d", got, want)
+ }
+ if !bytes.Equal(targetBuf, expected) {
+ t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+ }
+
+ opt, ok := next.(NDPRecursiveDNSServer)
+ if !ok {
+ t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next)
+ }
+ if got := opt.Type(); got != 25 {
+ t.Errorf("got Type = %d, want = 31", got)
+ }
+ if got := opt.Length(); got != 22 {
+ t.Errorf("got Length = %d, want = 22", got)
+ }
+ if got, want := opt.Lifetime(), 16909320*time.Second; got != want {
+ t.Errorf("got Lifetime = %s, want = %s", got, want)
+ }
+ want := []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ }
+ if got := opt.Addresses(); !cmp.Equal(got, want) {
+ t.Errorf("got Addresses = %v, want = %v", got, want)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+}
+
+func TestNDPRecursiveDNSServerOption(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ lifetime time.Duration
+ addrs []tcpip.Address
+ }{
+ {
+ "Valid1Addr",
+ []byte{
+ 25, 3, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ },
+ 0,
+ []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ },
+ },
+ {
+ "Valid2Addr",
+ []byte{
+ 25, 5, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
+ },
+ 0,
+ []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
+ },
+ },
+ {
+ "Valid3Addr",
+ []byte{
+ 25, 7, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
+ 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17,
+ },
+ 0,
+ []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
+ "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11",
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ // Iterator should get our option.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
+ t.Fatalf("got Type %= %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+ }
+
+ opt, ok := next.(NDPRecursiveDNSServer)
+ if !ok {
+ t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next)
+ }
+ if got := opt.Lifetime(); got != test.lifetime {
+ t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
+ }
+ if got := opt.Addresses(); !cmp.Equal(got, test.addrs) {
+ t.Errorf("got Addresses = %v, want = %v", got, test.addrs)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
})
}
}
+
+// TestNDPOptionsIterCheck tests that Iter returns an error if the NDPOptions
+// it was called on is malformed.
+func TestNDPOptionsIterCheck(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expected error
+ }{
+ {
+ "ZeroLengthField",
+ []byte{0, 0, 0, 0, 0, 0, 0, 0},
+ ErrNDPOptZeroLength,
+ },
+ {
+ "ValidTargetLinkLayerAddressOption",
+ []byte{2, 1, 1, 2, 3, 4, 5, 6},
+ nil,
+ },
+ {
+ "TooSmallTargetLinkLayerAddressOption",
+ []byte{2, 1, 1, 2, 3, 4, 5},
+ ErrNDPOptBufExhausted,
+ },
+ {
+ "ValidPrefixInformation",
+ []byte{
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ },
+ nil,
+ },
+ {
+ "TooSmallPrefixInformation",
+ []byte{
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23,
+ },
+ ErrNDPOptBufExhausted,
+ },
+ {
+ "InvalidPrefixInformationLength",
+ []byte{
+ 3, 3, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ },
+ ErrNDPOptMalformedBody,
+ },
+ {
+ "ValidTargetLinkLayerAddressWithPrefixInformation",
+ []byte{
+ // Target Link-Layer Address.
+ 2, 1, 1, 2, 3, 4, 5, 6,
+
+ // Prefix information.
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ },
+ nil,
+ },
+ {
+ "ValidTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
+ []byte{
+ // Target Link-Layer Address.
+ 2, 1, 1, 2, 3, 4, 5, 6,
+
+ // 255 is an unrecognized type. If 255 ends up
+ // being the type for some recognized type,
+ // update 255 to some other unrecognized value.
+ 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8,
+
+ // Prefix information.
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ },
+ nil,
+ },
+ {
+ "InvalidRecursiveDNSServerCutsOffAddress",
+ []byte{
+ 25, 4, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ 0, 1, 2, 3, 4, 5, 6, 7,
+ },
+ ErrNDPOptMalformedBody,
+ },
+ {
+ "InvalidRecursiveDNSServerInvalidLengthField",
+ []byte{
+ 25, 2, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8,
+ },
+ ErrNDPInvalidLength,
+ },
+ {
+ "RecursiveDNSServerTooSmall",
+ []byte{
+ 25, 1, 0, 0,
+ 0, 0, 0,
+ },
+ ErrNDPOptBufExhausted,
+ },
+ {
+ "RecursiveDNSServerMulticast",
+ []byte{
+ 25, 3, 0, 0,
+ 0, 0, 0, 0,
+ 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ },
+ ErrNDPOptMalformedBody,
+ },
+ {
+ "RecursiveDNSServerUnspecified",
+ []byte{
+ 25, 3, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ ErrNDPOptMalformedBody,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+
+ if _, err := opts.Iter(true); err != test.expected {
+ t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expected)
+ }
+
+ // test.buf may be malformed, but since we chose not to
+ // check, Iter must not return an error.
+ if _, err := opts.Iter(false); err != nil {
+ t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err)
+ }
+ })
+ }
+}
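+
+// The cases above trace the TLV walk Iter(true) presumably performs: a
+// zero length byte yields ErrNDPOptZeroLength, a buffer holding fewer
+// than length*8-2 body bytes yields ErrNDPOptBufExhausted, and
+// option-specific body checks (a prefix information body of the wrong
+// size, an RDNSS address section that is cut short, multicast or
+// unspecified RDNSS addresses) yield ErrNDPOptMalformedBody, with a bad
+// RDNSS length field reported as ErrNDPInvalidLength.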
+
+// TestNDPOptionsIter tests that we can iterate over a valid NDPOptions. Note,
+// this test does not actually check any of the option's getters, it simply
+// checks the option Type and Body. We have other tests that exercise the
+// option field getters given an option body and don't need to duplicate those
+// tests here.
+func TestNDPOptionsIter(t *testing.T) {
+ buf := []byte{
+ // Target Link-Layer Address.
+ 2, 1, 1, 2, 3, 4, 5, 6,
+
+ // 255 is an unrecognized type. If 255 ends up being the type
+ // for some recognized type, update 255 to some other
+ // unrecognized value. Note, this option should be skipped when
+ // iterating.
+ 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8,
+
+ // Prefix information.
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ }
+
+ opts := NDPOptions(buf)
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ // Test the first (Target Link-Layer) option.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) {
+ t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+ if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
+ }
+
+ // Test the next (Prefix Information) option.
+ // Note, the unrecognized option should be skipped.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got, want := next.(NDPPrefixInformation), buf[26:][:30]; !bytes.Equal(got, want) {
+ t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+ if got := next.Type(); got != NDPPrefixInformationType {
+ t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+}
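
A minimal sketch of how a consumer might drive this API, assuming only the
Iter/Next/type-switch shape exercised by the tests above (handleLinkAddr and
handlePrefix are hypothetical helpers, not part of the package):

	func consumeNDPOptions(optionBytes []byte) error {
		opts := header.NDPOptions(optionBytes)
		it, err := opts.Iter(true) // validate the TLV structure up front
		if err != nil {
			return err
		}
		for {
			opt, done, err := it.Next()
			if done || err != nil {
				return err
			}
			switch o := opt.(type) {
			case header.NDPTargetLinkLayerAddressOption:
				handleLinkAddr(o.EthernetAddress()) // hypothetical helper
			case header.NDPPrefixInformation:
				handlePrefix(o.Prefix(), o.ValidLifetime()) // hypothetical helper
			}
		}
	}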
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 97a794986..7dbc05754 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -6,7 +6,7 @@ go_library(
name = "channel",
srcs = ["channel.go"],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel",
- visibility = ["//:sandbox"],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 18adb2085..70188551f 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -25,10 +25,9 @@ import (
// PacketInfo holds all the information about an outbound packet.
type PacketInfo struct {
- Header buffer.View
- Payload buffer.View
- Proto tcpip.NetworkProtocolNumber
- GSO *stack.GSO
+ Pkt tcpip.PacketBuffer
+ Proto tcpip.NetworkProtocolNumber
+ GSO *stack.GSO
}
// Endpoint is link layer endpoint that stores outbound packets in a channel
@@ -65,14 +64,14 @@ func (e *Endpoint) Drain() int {
}
}
-// Inject injects an inbound packet.
-func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
- e.InjectLinkAddr(protocol, "", vv)
+// InjectInbound injects an inbound packet.
+func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+ e.InjectLinkAddr(protocol, "", pkt)
}
// InjectLinkAddr injects an inbound packet with a remote link address.
-func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, vv buffer.VectorisedView) {
- e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil))
+func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt)
}
// Attach saves the stack network-layer dispatcher for use later when packets
@@ -96,7 +95,7 @@ func (e *Endpoint) MTU() uint32 {
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
caps := stack.LinkEndpointCapabilities(0)
if e.GSO {
- caps |= stack.CapabilityGSO
+ caps |= stack.CapabilityHardwareGSO
}
return caps
}
@@ -118,12 +117,55 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket stores outbound packets into the channel.
-func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
p := PacketInfo{
- Header: hdr.View(),
- Proto: protocol,
- Payload: payload.ToView(),
- GSO: gso,
+ Pkt: pkt,
+ Proto: protocol,
+ GSO: gso,
+ }
+
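+ // Non-blocking send: if e.C is full, the default case is taken and
+ // the packet is silently dropped rather than blocking the writer.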
+ select {
+ case e.C <- p:
+ default:
+ }
+
+ return nil
+}
+
+// WritePackets stores outbound packets into the channel.
+func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ payloadView := pkts[0].Data.ToView()
+ n := 0
+packetLoop:
+ for _, pkt := range pkts {
+ off := pkt.DataOffset
+ size := pkt.DataSize
+ p := PacketInfo{
+ Pkt: tcpip.PacketBuffer{
+ Header: pkt.Header,
+ Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(),
+ },
+ Proto: protocol,
+ GSO: gso,
+ }
+
+ select {
+ case e.C <- p:
+ n++
+ default:
+ break packetLoop
+ }
+ }
+
+ return n, nil
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ p := PacketInfo{
+ Pkt: tcpip.PacketBuffer{Data: vv},
+ Proto: 0,
+ GSO: nil,
}
select {
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index 8fa9e3984..897c94821 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -14,9 +14,7 @@ go_library(
"packet_dispatchers.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index f80ac3435..fa8a703d9 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -165,6 +165,9 @@ type Options struct {
// disabled.
GSOMaxSize uint32
+ // SoftwareGSOEnabled indicates whether software GSO is enabled or not.
+ SoftwareGSOEnabled bool
+
// PacketDispatchMode specifies the type of inbound dispatcher to be
// used for this endpoint.
PacketDispatchMode PacketDispatchMode
@@ -242,7 +245,11 @@ func New(opts *Options) (stack.LinkEndpoint, error) {
}
if isSocket {
if opts.GSOMaxSize != 0 {
- e.caps |= stack.CapabilityGSO
+ if opts.SoftwareGSOEnabled {
+ e.caps |= stack.CapabilitySoftwareGSO
+ } else {
+ e.caps |= stack.CapabilityHardwareGSO
+ }
e.gsoMaxSize = opts.GSOMaxSize
}
}
@@ -379,10 +386,11 @@ const (
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
if e.hdrSize > 0 {
// Add ethernet header if needed.
- eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize))
+ eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
+ pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
DstAddr: r.RemoteLinkAddress,
Type: protocol,
@@ -397,17 +405,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
eth.Encode(ethHdr)
}
- if e.Capabilities()&stack.CapabilityGSO != 0 {
+ if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
vnetHdr := virtioNetHdr{}
vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr)
if gso != nil {
- vnetHdr.hdrLen = uint16(hdr.UsedLength())
+ vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
if gso.NeedsCsum {
vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen
vnetHdr.csumOffset = gso.CsumOffset
}
- if gso.Type != stack.GSONone && uint16(payload.Size()) > gso.MSS {
+ if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS {
switch gso.Type {
case stack.GSOTCPv4:
vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
@@ -420,18 +428,151 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
}
}
- return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, hdr.View(), payload.ToView())
+ return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView())
+ }
+
+ if pkt.Data.Size() == 0 {
+ return rawfile.NonBlockingWrite(e.fds[0], pkt.Header.View())
}
- if payload.Size() == 0 {
- return rawfile.NonBlockingWrite(e.fds[0], hdr.View())
+ return rawfile.NonBlockingWrite3(e.fds[0], pkt.Header.View(), pkt.Data.ToView(), nil)
+}
+
+// WritePackets writes outbound packets to the file descriptor. If it is not
+// currently writable, packets are dropped.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ var ethHdrBuf []byte
+ // hdr + data
+ iovLen := 2
+ if e.hdrSize > 0 {
+ // Add ethernet header if needed.
+ ethHdrBuf = make([]byte, header.EthernetMinimumSize)
+ eth := header.Ethernet(ethHdrBuf)
+ ethHdr := &header.EthernetFields{
+ DstAddr: r.RemoteLinkAddress,
+ Type: protocol,
+ }
+
+ // Preserve the src address if it's set in the route.
+ if r.LocalLinkAddress != "" {
+ ethHdr.SrcAddr = r.LocalLinkAddress
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+ iovLen++
}
- return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil)
+ n := len(pkts)
+
+ views := pkts[0].Data.Views()
+ /*
+ * Each boundary in views can add one more iovec.
+ *
+ * payload | | | |
+ * -----------------------------
+ * packets | | | | | | |
+ * -----------------------------
+ * iovecs | | | | | | | | |
+ */
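+ // Illustrative worst case: with 3 packets, an ethernet header
+ // (iovLen = 3) and 4 views, the payload boundaries can add
+ // len(views)-1 extra iovecs, so we reserve 3*3 + 4 - 1 = 12 entries.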
+ iovec := make([]syscall.Iovec, n*iovLen+len(views)-1)
+ mmsgHdrs := make([]rawfile.MMsgHdr, n)
+
+ iovecIdx := 0
+ viewIdx := 0
+ viewOff := 0
+ off := 0
+ nextOff := 0
+ for i := range pkts {
+ // TODO(b/134618279): Different packets may have different data
+ // in the future. We should handle this.
+ if !viewsEqual(pkts[i].Data.Views(), views) {
+ panic("All packets in pkts should have the same Data.")
+ }
+
+ prevIovecIdx := iovecIdx
+ mmsgHdr := &mmsgHdrs[i]
+ mmsgHdr.Msg.Iov = &iovec[iovecIdx]
+ packetSize := pkts[i].DataSize
+ hdr := &pkts[i].Header
+
+ off = pkts[i].DataOffset
+ if off != nextOff {
+ // We stopped at a different point last time, so
+ // recompute the view index and offset for off.
+ size := off
+ viewIdx = 0
+ viewOff = 0
+ for size > 0 {
+ if size >= len(views[viewIdx]) {
+ size -= len(views[viewIdx])
+ viewIdx++
+ viewOff = 0
+ } else {
+ viewOff = size
+ size = 0
+ }
+ }
+ }
+ nextOff = off + packetSize
+
+ if ethHdrBuf != nil {
+ v := &iovec[iovecIdx]
+ v.Base = &ethHdrBuf[0]
+ v.Len = uint64(len(ethHdrBuf))
+ iovecIdx++
+ }
+
+ v := &iovec[iovecIdx]
+ hdrView := hdr.View()
+ v.Base = &hdrView[0]
+ v.Len = uint64(len(hdrView))
+ iovecIdx++
+
+ for packetSize > 0 {
+ vec := &iovec[iovecIdx]
+ iovecIdx++
+
+ v := views[viewIdx]
+ vec.Base = &v[viewOff]
+ s := len(v) - viewOff
+ if s <= packetSize {
+ viewIdx++
+ viewOff = 0
+ } else {
+ s = packetSize
+ viewOff += s
+ }
+ vec.Len = uint64(s)
+ packetSize -= s
+ }
+
+ mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx)
+ }
+
+ packets := 0
+ for packets < n {
+ sent, err := rawfile.NonBlockingSendMMsg(e.fds[0], mmsgHdrs)
+ if err != nil {
+ return packets, err
+ }
+ packets += sent
+ mmsgHdrs = mmsgHdrs[sent:]
+ }
+ return packets, nil
+}
+
+// viewsEqual tests whether vs1 and vs2 refer to the same backing bytes.
+func viewsEqual(vs1, vs2 []buffer.View) bool {
+ return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0])
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ return rawfile.NonBlockingWrite(e.fds[0], vv.ToView())
}
-// WriteRawPacket writes a raw packet directly to the file descriptor.
-func (e *endpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error {
+// InjectOutbound implements stack.InjectableEndpoint.InjectOutbound.
+func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
return rawfile.NonBlockingWrite(e.fds[0], packet)
}
@@ -468,9 +609,9 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
}
-// Inject injects an inbound packet.
-func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv)
+// InjectInbound injects an inbound packet.
+func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt)
}
// NewInjectable creates a new fd-based InjectableEndpoint.
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 04406bc9a..2066987eb 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -45,7 +45,7 @@ const (
type packetInfo struct {
raddr tcpip.LinkAddress
proto tcpip.NetworkProtocolNumber
- contents buffer.View
+ contents tcpip.PacketBuffer
}
type context struct {
@@ -92,8 +92,8 @@ func (c *context) cleanup() {
syscall.Close(c.fds[1])
}
-func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
- c.ch <- packetInfo{remote, protocol, vv.ToView()}
+func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+ c.ch <- packetInfo{remote, protocol, pkt}
}
func TestNoEthernetProperties(t *testing.T) {
@@ -168,7 +168,10 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) {
L3HdrLen: header.IPv4MaximumHeaderSize,
}
}
- if err := c.ep.WritePacket(r, gso, hdr, payload.ToVectorisedView(), proto); err != nil {
+ if err := c.ep.WritePacket(r, gso, proto, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload.ToVectorisedView(),
+ }); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -258,7 +261,10 @@ func TestPreserveSrcAddress(t *testing.T) {
// WritePacket panics given a prependable with anything less than
// the minimum size of the ethernet header.
hdr := buffer.NewPrependable(header.EthernetMinimumSize)
- if err := c.ep.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil {
+ if err := c.ep.WritePacket(r, nil /* gso */, proto, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buffer.VectorisedView{},
+ }); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -293,11 +299,12 @@ func TestDeliverPacket(t *testing.T) {
b[i] = uint8(rand.Intn(256))
}
+ var hdr header.Ethernet
if !eth {
// So that it looks like an IPv4 packet.
b[0] = 0x40
} else {
- hdr := make(header.Ethernet, header.EthernetMinimumSize)
+ hdr = make(header.Ethernet, header.EthernetMinimumSize)
hdr.Encode(&header.EthernetFields{
SrcAddr: raddr,
DstAddr: laddr,
@@ -315,14 +322,21 @@ func TestDeliverPacket(t *testing.T) {
select {
case pi := <-c.ch:
want := packetInfo{
- raddr: raddr,
- proto: proto,
- contents: b,
+ raddr: raddr,
+ proto: proto,
+ contents: tcpip.PacketBuffer{
+ Data: buffer.View(b).ToVectorisedView(),
+ LinkHeader: buffer.View(hdr),
+ },
}
if !eth {
want.proto = header.IPv4ProtocolNumber
want.raddr = ""
}
+ // want.contents.Data will be a single
+ // view, so make pi do the same for the
+ // DeepEqual check.
+ pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView()
if !reflect.DeepEqual(want, pi) {
t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want)
}
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index 8bfeb97e4..62ed1e569 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -169,9 +169,10 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
+ eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth := header.Ethernet(pkt)
+ eth = header.Ethernet(pkt)
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -189,6 +190,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
}
pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, tcpip.PacketBuffer{
+ Data: buffer.View(pkt).ToVectorisedView(),
+ LinkHeader: buffer.View(eth),
+ })
return true, nil
}
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index 7ca217e5b..c67d684ce 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -53,7 +53,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
d := &readVDispatcher{fd: fd, e: e}
d.views = make([]buffer.View, len(BufConfig))
iovLen := len(BufConfig)
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
iovLen++
}
d.iovecs = make([]syscall.Iovec, iovLen)
@@ -63,7 +63,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
func (d *readVDispatcher) allocateViews(bufConfig []int) {
var vnetHdr [virtioNetHdrSize]byte
vnetHdrOff := 0
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
// The kernel adds virtioNetHdr before each packet, but
 		// we don't use it, so we allocate a buffer for it,
// add it in iovecs but don't add it in a view.
@@ -106,7 +106,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
if err != nil {
return false, err
}
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
// Skip virtioNetHdr which is added before each packet, it
// isn't used and it isn't in a view.
n -= virtioNetHdrSize
@@ -118,9 +118,10 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
+ eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth := header.Ethernet(d.views[0])
+ eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize])
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -138,10 +139,13 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
}
used := d.capViews(n, BufConfig)
- vv := buffer.NewVectorisedView(n, d.views[:used])
- vv.TrimFront(d.e.hdrSize)
+ pkt := tcpip.PacketBuffer{
+ Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)),
+ LinkHeader: buffer.View(eth),
+ }
+ pkt.Data.TrimFront(d.e.hdrSize)
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv)
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt)
// Prepare e.views for another packet: release used views.
for i := 0; i < used; i++ {
@@ -194,7 +198,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
}
d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv)
iovLen := len(BufConfig)
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
// virtioNetHdr is prepended before each packet.
iovLen++
}
@@ -225,7 +229,7 @@ func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) {
for k := 0; k < len(d.views); k++ {
var vnetHdr [virtioNetHdrSize]byte
vnetHdrOff := 0
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
// The kernel adds virtioNetHdr before each packet, but
// we don't use it, so so we allocate a buffer for it,
// add it in iovecs but don't add it in a view.
@@ -261,7 +265,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
// Process each of received packets.
for k := 0; k < nMsgs; k++ {
n := int(d.msgHdrs[k].Len)
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
n -= virtioNetHdrSize
}
if n <= d.e.hdrSize {
@@ -271,9 +275,10 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
+ eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth := header.Ethernet(d.views[k][0])
+ eth = header.Ethernet(d.views[k][0])
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -291,9 +296,12 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
}
used := d.capViews(k, int(n), BufConfig)
- vv := buffer.NewVectorisedView(int(n), d.views[k][:used])
- vv.TrimFront(d.e.hdrSize)
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv)
+ pkt := tcpip.PacketBuffer{
+ Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)),
+ LinkHeader: buffer.View(eth),
+ }
+ pkt.Data.TrimFront(d.e.hdrSize)
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt)
// Prepare e.views for another packet: release used views.
for i := 0; i < used; i++ {
diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD
index 47a54845c..f35fcdff4 100644
--- a/pkg/tcpip/link/loopback/BUILD
+++ b/pkg/tcpip/link/loopback/BUILD
@@ -6,10 +6,11 @@ go_library(
name = "loopback",
srcs = ["loopback.go"],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback",
- visibility = ["//:sandbox"],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index b36629d2c..499cc608f 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -23,6 +23,7 @@ package loopback
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -70,21 +71,45 @@ func (*endpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
+
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
- views := make([]buffer.View, 1, 1+len(payload.Views()))
- views[0] = hdr.View()
- views = append(views, payload.Views()...)
- vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
- // Because we're immediately turning around and writing the packet back to the
- // rx path, we intentionally don't preserve the remote and local link
- // addresses from the stack.Route we're passed.
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv)
+ // Because we're immediately turning around and writing the packet back
+ // to the rx path, we intentionally don't preserve the remote and local
+ // link addresses from the stack.Route we're passed.
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, tcpip.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
+ })
return nil
}
-// Wait implements stack.LinkEndpoint.Wait.
-func (*endpoint) Wait() {}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ // Reject the packet if it's shorter than an ethernet header.
+ if vv.Size() < header.EthernetMinimumSize {
+ return tcpip.ErrBadAddress
+ }
+
+ // There should be an ethernet header at the beginning of vv.
+ linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize])
+ vv.TrimFront(len(linkHeader))
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), tcpip.PacketBuffer{
+ Data: vv,
+ LinkHeader: buffer.View(linkHeader),
+ })
+
+ return nil
+}
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index 1bab380b0..1ac7948b6 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -7,9 +7,7 @@ go_library(
name = "muxed",
srcs = ["injectable.go"],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 7c946101d..445b22c17 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -79,29 +79,47 @@ func (m *InjectableEndpoint) IsAttached() bool {
return m.dispatcher != nil
}
-// Inject implements stack.InjectableLinkEndpoint.
-func (m *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
- m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv)
+// InjectInbound implements stack.InjectableLinkEndpoint.
+func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+ m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt)
+}
+
+// WritePackets writes outbound packets to the appropriate
+// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if
+// r.RemoteAddress has a route registered in this endpoint.
+func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ endpoint, ok := m.routes[r.RemoteAddress]
+ if !ok {
+ return 0, tcpip.ErrNoRoute
+ }
+ return endpoint.WritePackets(r, gso, pkts, protocol)
}
// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint
// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a
// route registered in this endpoint.
-func (m *InjectableEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
if endpoint, ok := m.routes[r.RemoteAddress]; ok {
- return endpoint.WritePacket(r, nil /* gso */, hdr, payload, protocol)
+ return endpoint.WritePacket(r, gso, protocol, pkt)
}
return tcpip.ErrNoRoute
}
-// WriteRawPacket writes outbound packets to the appropriate
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ // WriteRawPacket doesn't get a route or network address, so there's
+ // nowhere to write this.
+ return tcpip.ErrNoRoute
+}
+
+// InjectOutbound writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the dest address.
-func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error {
+func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
endpoint, ok := m.routes[dest]
if !ok {
return tcpip.ErrNoRoute
}
- return endpoint.WriteRawPacket(dest, packet)
+ return endpoint.InjectOutbound(dest, packet)
}
// Wait implements stack.LinkEndpoint.Wait.
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 3086fec00..63b249837 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -31,7 +31,7 @@ import (
func TestInjectableEndpointRawDispatch(t *testing.T) {
endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
- endpoint.WriteRawPacket(dstIP, []byte{0xFA})
+ endpoint.InjectOutbound(dstIP, []byte{0xFA})
buf := make([]byte, ipv4.MaxTotalSize)
bytesRead, err := sock.Read(buf)
@@ -50,8 +50,10 @@ func TestInjectableEndpointDispatch(t *testing.T) {
hdr.Prepend(1)[0] = 0xFA
packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, hdr,
- buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), ipv4.ProtocolNumber)
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
+ })
buf := make([]byte, 6500)
bytesRead, err := sock.Read(buf)
@@ -68,8 +70,10 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
hdr := buffer.NewPrependable(1)
hdr.Prepend(1)[0] = 0xFA
packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, hdr,
- buffer.NewView(0).ToVectorisedView(), ipv4.ProtocolNumber)
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buffer.NewView(0).ToVectorisedView(),
+ })
buf := make([]byte, 6500)
bytesRead, err := sock.Read(buf)
if err != nil {
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 2e8bc772a..d8211e93d 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -13,8 +13,9 @@ go_library(
"rawfile_unsafe.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile",
- visibility = [
- "//visibility:public",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "@org_golang_x_sys//unix:go_default_library",
],
- deps = ["//pkg/tcpip"],
)
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index dda3b10a6..0b5a6cf49 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -14,7 +14,7 @@
// +build linux,amd64 linux,arm64
// +build go1.12
-// +build !go1.14
+// +build !go1.15
// Check go:linkname function signatures when updating Go version.
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 7e286a3a6..44e25d475 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -22,6 +22,7 @@ import (
"syscall"
"unsafe"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -101,6 +102,16 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error {
return nil
}
+// NonBlockingSendMMsg sends multiple messages on a socket.
+func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) {
+ n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0)
+ if e != 0 {
+ return 0, TranslateErrno(e)
+ }
+
+ return int(n), nil
+}
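+
+// A typical call pattern (mirroring WritePackets in fdbased above): fill
+// in each MMsgHdr's Msg.Iov and Msg.Iovlen, call NonBlockingSendMMsg in a
+// loop, and slice past the entries it reports sent until none remain.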
+
// PollEvent represents the pollfd structure passed to a poll() system call.
type PollEvent struct {
FD int32
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index 0a5ea3dc4..a4f9cdd69 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -12,9 +12,7 @@ go_library(
"tx.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem",
- visibility = [
- "//:sandbox",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/log",
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
index 330ed5e94..6b5bc542c 100644
--- a/pkg/tcpip/link/sharedmem/pipe/BUILD
+++ b/pkg/tcpip/link/sharedmem/pipe/BUILD
@@ -12,7 +12,7 @@ go_library(
"tx.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe",
- visibility = ["//:sandbox"],
+ visibility = ["//visibility:public"],
)
go_test(
diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD
index de1ce043d..8c9234d54 100644
--- a/pkg/tcpip/link/sharedmem/queue/BUILD
+++ b/pkg/tcpip/link/sharedmem/queue/BUILD
@@ -10,7 +10,7 @@ go_library(
"tx.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue",
- visibility = ["//:sandbox"],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/log",
"//pkg/tcpip/link/sharedmem/pipe",
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 9e71d4edf..080f9d667 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -185,9 +185,10 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
// Add the ethernet header here.
- eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize))
+ eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
+ pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
DstAddr: r.RemoteLinkAddress,
Type: protocol,
@@ -199,10 +200,30 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependa
}
eth.Encode(ethHdr)
- v := payload.ToView()
+ v := pkt.Data.ToView()
// Transmit the packet.
e.mu.Lock()
- ok := e.tx.transmit(hdr.View(), v)
+ ok := e.tx.transmit(pkt.Header.View(), v)
+ e.mu.Unlock()
+
+ if !ok {
+ return tcpip.ErrWouldBlock
+ }
+
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ v := vv.ToView()
+ // Transmit the packet.
+ e.mu.Lock()
+ ok := e.tx.transmit(v, buffer.View{})
e.mu.Unlock()
if !ok {
@@ -253,8 +274,11 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
}
// Send packet up the stack.
- eth := header.Ethernet(b)
- d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView())
+ eth := header.Ethernet(b[:header.EthernetMinimumSize])
+ d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), tcpip.PacketBuffer{
+ Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(),
+ LinkHeader: buffer.View(eth),
+ })
}
// Clean state.
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 0e9ba0846..89603c48f 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -78,9 +78,10 @@ func (q *queueBuffers) cleanup() {
}
type packetInfo struct {
- addr tcpip.LinkAddress
- proto tcpip.NetworkProtocolNumber
- vv buffer.VectorisedView
+ addr tcpip.LinkAddress
+ proto tcpip.NetworkProtocolNumber
+ vv buffer.VectorisedView
+ linkHeader buffer.View
}
type testContext struct {
@@ -130,12 +131,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
return c
}
-func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
c.mu.Lock()
c.packets = append(c.packets, packetInfo{
addr: remoteLinkAddr,
proto: proto,
- vv: vv.Clone(nil),
+ vv: pkt.Data.Clone(nil),
})
c.mu.Unlock()
@@ -272,7 +273,10 @@ func TestSimpleSend(t *testing.T) {
randomFill(buf)
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), proto); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -341,7 +345,9 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
hdr := buffer.NewPrependable(header.EthernetMinimumSize)
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{
+ Header: hdr,
+ }); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -395,7 +401,10 @@ func TestFillTxQueue(t *testing.T) {
for i := queuePipeSize / 40; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -410,7 +419,10 @@ func TestFillTxQueue(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want {
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
}
@@ -435,7 +447,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// Send two packets so that the id slice has at least two slots.
for i := 2; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
}
@@ -455,7 +470,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
ids := make(map[uint64]struct{})
for i := queuePipeSize / 40; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -470,7 +488,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want {
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
}
@@ -493,7 +514,10 @@ func TestFillTxMemory(t *testing.T) {
ids := make(map[uint64]struct{})
for i := queueDataSize / bufferSize; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -509,7 +533,10 @@ func TestFillTxMemory(t *testing.T) {
// Next attempt to write must fail.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber)
+ err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ })
if want := tcpip.ErrWouldBlock; err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
@@ -534,7 +561,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// until there is only one buffer left.
for i := queueDataSize/bufferSize - 1; i > 0; i-- {
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -547,7 +577,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
{
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
uu := buffer.NewView(bufferSize).ToVectorisedView()
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, uu, header.IPv4ProtocolNumber); err != want {
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: uu,
+ }); err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
}
@@ -555,7 +588,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// Attempt to write the one-buffer packet again. It must succeed.
{
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buf.ToVectorisedView(),
+ }); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
}
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
index 1756114e6..d6ae0368a 100644
--- a/pkg/tcpip/link/sniffer/BUILD
+++ b/pkg/tcpip/link/sniffer/BUILD
@@ -9,9 +9,7 @@ go_library(
"sniffer.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/log",
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index e401dce44..3392b7edd 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -49,6 +49,13 @@ var LogPackets uint32 = 1
// LogPacketsToFile must be accessed atomically.
var LogPacketsToFile uint32 = 1
+var transportProtocolMinSizes map[tcpip.TransportProtocolNumber]int = map[tcpip.TransportProtocolNumber]int{
+ header.ICMPv4ProtocolNumber: header.IPv4MinimumSize,
+ header.ICMPv6ProtocolNumber: header.IPv6MinimumSize,
+ header.UDPProtocolNumber: header.UDPMinimumSize,
+ header.TCPProtocolNumber: header.TCPMinimumSize,
+}
+
type endpoint struct {
dispatcher stack.NetworkDispatcher
lower stack.LinkEndpoint
@@ -116,19 +123,19 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
-func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("recv", protocol, vv.First(), nil)
+ logPacket("recv", protocol, pkt.Data.First(), nil)
}
if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
- vs := vv.Views()
- length := vv.Size()
+ vs := pkt.Data.Views()
+ length := pkt.Data.Size()
if length > int(e.maxPCAPLen) {
length = int(e.maxPCAPLen)
}
buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
- if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil {
+ if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(pkt.Data.Size()))); err != nil {
panic(err)
}
for _, v := range vs {
@@ -147,7 +154,7 @@ func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local
panic(err)
}
}
- e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv)
+ e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt)
}
// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher
@@ -193,22 +200,19 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
-// WritePacket implements the stack.LinkEndpoint interface. It is called by
-// higher-level protocols to write packets; it just logs the packet and forwards
-// the request to the lower endpoint.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("send", protocol, hdr.View(), gso)
+ logPacket("send", protocol, pkt.Header.View(), gso)
}
if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
- hdrBuf := hdr.View()
- length := len(hdrBuf) + payload.Size()
+ hdrBuf := pkt.Header.View()
+ length := len(hdrBuf) + pkt.Data.Size()
if length > int(e.maxPCAPLen) {
length = int(e.maxPCAPLen)
}
buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
- if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+payload.Size()))); err != nil {
+ if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+pkt.Data.Size()))); err != nil {
panic(err)
}
if len(hdrBuf) > length {
@@ -218,26 +222,75 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
panic(err)
}
length -= len(hdrBuf)
- if length > 0 {
- for _, v := range payload.Views() {
- if len(v) > length {
- v = v[:length]
- }
- n, err := buf.Write(v)
- if err != nil {
- panic(err)
- }
- length -= n
- if length == 0 {
- break
- }
- }
+ logVectorisedView(pkt.Data, length, buf)
+ if _, err := e.file.Write(buf.Bytes()); err != nil {
+ panic(err)
}
+ }
+}
+
+// WritePacket implements the stack.LinkEndpoint interface. It is called by
+// higher-level protocols to write packets; it just logs the packet and
+// forwards the request to the lower endpoint.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+ e.dumpPacket(gso, protocol, pkt)
+ return e.lower.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements the stack.LinkEndpoint interface. It is called by
+// higher-level protocols to write packets; it logs each packet and
+// forwards the request to the lower endpoint.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ view := pkts[0].Data.ToView()
+ for _, pkt := range pkts {
+ e.dumpPacket(gso, protocol, tcpip.PacketBuffer{
+ Header: pkt.Header,
+ Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(),
+ })
+ }
+ return e.lower.WritePackets(r, gso, pkts, protocol)
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
+ logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */)
+ }
+ if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
+ length := vv.Size()
+ if length > int(e.maxPCAPLen) {
+ length = int(e.maxPCAPLen)
+ }
+
+ buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
+ if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil {
+ panic(err)
+ }
+ logVectorisedView(vv, length, buf)
if _, err := e.file.Write(buf.Bytes()); err != nil {
panic(err)
}
}
- return e.lower.WritePacket(r, gso, hdr, payload, protocol)
+ return e.lower.WriteRawPacket(vv)
+}
+
+func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) {
+ if length <= 0 {
+ return
+ }
+ for _, v := range vv.Views() {
+ if len(v) > length {
+ v = v[:length]
+ }
+ n, err := buf.Write(v)
+ if err != nil {
+ panic(err)
+ }
+ length -= n
+ if length == 0 {
+ return
+ }
+ }
}
// Wait implements stack.LinkEndpoint.Wait.
@@ -287,6 +340,13 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
return
}
+ // We aren't guaranteed to have a transport header - it's possible for
+ // writes via raw endpoints to contain only network headers.
+ if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize {
+ log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto)
+ return
+ }
+
// Figure out the transport layer info.
transName := "unknown"
srcPort := uint16(0)
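
The rewrite above factors the truncating copy loop out of WritePacket into logVectorisedView, so WritePacket, WritePackets, and WriteRawPacket can all share it when capping a pcap record at maxPCAPLen. A self-contained sketch of that capped-copy behavior, with plain byte slices standing in for buffer.VectorisedView:

package main

import (
	"bytes"
	"fmt"
)

// copyCapped writes at most limit bytes from views into buf, mirroring
// logVectorisedView above.
func copyCapped(views [][]byte, limit int, buf *bytes.Buffer) {
	if limit <= 0 {
		return
	}
	for _, v := range views {
		if len(v) > limit {
			v = v[:limit]
		}
		n, _ := buf.Write(v) // bytes.Buffer.Write never returns an error
		limit -= n
		if limit == 0 {
			return
		}
	}
}

func main() {
	var buf bytes.Buffer
	copyCapped([][]byte{[]byte("hello, "), []byte("world")}, 9, &buf)
	fmt.Printf("%q\n", buf.String()) // "hello, wo"
}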
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index 92dce8fac..a71a493fc 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -6,7 +6,5 @@ go_library(
name = "tun",
srcs = ["tun_unsafe.go"],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
)
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 0746dc8ec..134837943 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -9,9 +9,7 @@ go_library(
"waitable.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/gate",
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 5a1791cb5..a8de38979 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -50,12 +50,12 @@ func New(lower stack.LinkEndpoint) *Endpoint {
// It is called by the link-layer endpoint being wrapped when a packet arrives,
// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't
// been called.
-func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
if !e.dispatchGate.Enter() {
return
}
- e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv)
+ e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt)
e.dispatchGate.Leave()
}
@@ -99,12 +99,36 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
if !e.writeGate.Enter() {
return nil
}
- err := e.lower.WritePacket(r, gso, hdr, payload, protocol)
+ err := e.lower.WritePacket(r, gso, protocol, pkt)
+ e.writeGate.Leave()
+ return err
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by
+// higher-level protocols to write packets. It only forwards packets to the
+// lower endpoint if Wait or WaitWrite haven't been called.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ if !e.writeGate.Enter() {
+ return len(pkts), nil
+ }
+
+ n, err := e.lower.WritePackets(r, gso, pkts, protocol)
+ e.writeGate.Leave()
+ return n, err
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ if !e.writeGate.Enter() {
+ return nil
+ }
+
+ err := e.lower.WriteRawPacket(vv)
e.writeGate.Leave()
return err
}
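
The waitable endpoint guards every forwarded operation with a gate: Enter fails once the gate is closed, and closing waits for in-flight calls to drain. Note that the gated WritePackets still reports all packets as written, so callers treat dropped packets as consumed. A minimal sketch of the pattern with a hypothetical simplified gate (the real primitive is gate.Gate from //pkg/gate, listed in the BUILD deps above):

package main

import (
	"fmt"
	"sync"
)

type gate struct {
	mu     sync.Mutex
	closed bool
	inUse  sync.WaitGroup
}

// Enter returns false once the gate is closed; otherwise it registers an
// in-flight operation.
func (g *gate) Enter() bool {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.closed {
		return false
	}
	g.inUse.Add(1)
	return true
}

func (g *gate) Leave() { g.inUse.Done() }

// Close stops new entries and waits for in-flight ones to finish, like the
// endpoint's Wait methods.
func (g *gate) Close() {
	g.mu.Lock()
	g.closed = true
	g.mu.Unlock()
	g.inUse.Wait()
}

func main() {
	var g gate
	if g.Enter() {
		fmt.Println("write forwarded")
		g.Leave()
	}
	g.Close()
	if !g.Enter() {
		fmt.Println("write dropped after close")
	}
}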
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index ae23c96b7..31b11a27a 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -35,7 +35,7 @@ type countedEndpoint struct {
dispatcher stack.NetworkDispatcher
}
-func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
e.dispatchCount++
}
@@ -65,7 +65,18 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+ e.writeCount++
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ e.writeCount += len(pkts)
+ return len(pkts), nil
+}
+
+func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
e.writeCount++
return nil
}
@@ -78,21 +89,21 @@ func TestWaitWrite(t *testing.T) {
wep := New(ep)
// Write and check that it goes through.
- wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0)
+ wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{})
if want := 1; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on dispatches, then try to write. It must go through.
wep.WaitDispatch()
- wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0)
+ wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{})
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on writes, then try to write. It must not go through.
wep.WaitWrite()
- wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0)
+ wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{})
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
@@ -109,21 +120,21 @@ func TestWaitDispatch(t *testing.T) {
}
// Dispatch and check that it goes through.
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{})
if want := 1; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on writes, then try to dispatch. It must go through.
wep.WaitWrite()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on dispatches, then try to dispatch. It must not go through.
wep.WaitDispatch()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
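
countedEndpoint is a plain counting test double: each write or dispatch bumps a counter instead of doing I/O, which is what lets these tests assert exact counts around the gates. A self-contained sketch of the pattern:

package main

import "fmt"

type writer interface{ writePacket() }

// countedWriter mirrors countedEndpoint: it records how many writes reach
// it rather than performing real I/O.
type countedWriter struct{ writeCount int }

func (c *countedWriter) writePacket() { c.writeCount++ }

// gatedWriter mirrors the waitable endpoint: once closed, writes are
// silently dropped instead of forwarded.
type gatedWriter struct {
	lower  writer
	closed bool
}

func (g *gatedWriter) writePacket() {
	if g.closed {
		return
	}
	g.lower.writePacket()
}

func main() {
	ep := &countedWriter{}
	wep := &gatedWriter{lower: ep}
	wep.writePacket()
	wep.closed = true // analogous to WaitWrite above
	wep.writePacket()
	fmt.Println(ep.writeCount) // 1: the second write was dropped
}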
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index df0d3a8c0..e7617229b 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -7,9 +7,7 @@ go_library(
name = "arp",
srcs = ["arp.go"],
importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 6b1e854dc..da8482509 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -42,7 +42,7 @@ const (
// endpoint implements stack.NetworkEndpoint.
type endpoint struct {
- nicid tcpip.NICID
+ nicID tcpip.NICID
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
}
@@ -58,7 +58,7 @@ func (e *endpoint) MTU() uint32 {
}
func (e *endpoint) NICID() tcpip.NICID {
- return e.nicid
+ return e.nicID
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
@@ -79,16 +79,21 @@ func (e *endpoint) MaxHeaderLength() uint16 {
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
-func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- v := vv.First()
+func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+ v := pkt.Data.First()
h := header.ARP(v)
if !h.IsValid() {
return
@@ -97,23 +102,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
switch h.Op() {
case header.ARPRequest:
localAddr := tcpip.Address(h.ProtocolAddressTarget())
- if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 {
+ if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
- pkt := header.ARP(hdr.Prepend(header.ARPSize))
- pkt.SetIPv4OverEthernet()
- pkt.SetOp(header.ARPReply)
- copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:])
- copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget())
- copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender())
- copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
- e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+ packet := header.ARP(hdr.Prepend(header.ARPSize))
+ packet.SetIPv4OverEthernet()
+ packet.SetOp(header.ARPReply)
+ copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
+ copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
+ copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
+ copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
+ e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ })
fallthrough // also fill the cache from requests
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
}
}
@@ -130,12 +137,12 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
if addrWithPrefix.Address != ProtocolAddress {
return nil, tcpip.ErrBadLocalAddress
}
return &endpoint{
- nicid: nicid,
+ nicID: nicID,
linkEP: sender,
linkAddrCache: linkAddrCache,
}, nil
@@ -160,7 +167,9 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
copy(h.ProtocolAddressSender(), localAddr)
copy(h.ProtocolAddressTarget(), addr)
- return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ })
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
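
With the new signature, packets that consist only of headers (like these ARP frames) travel as a PacketBuffer whose Data field is left zero. A hedged sketch of building such a header-only packet, assuming the gvisor.dev/gvisor module at this revision and an illustrative 14-byte Ethernet headroom:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

func main() {
	// Reserve link-layer headroom plus space for the ARP body, as the
	// request/reply paths above do.
	hdr := buffer.NewPrependable(14 /* assumed Ethernet header length */ + header.ARPSize)
	p := header.ARP(hdr.Prepend(header.ARPSize))
	p.SetIPv4OverEthernet()
	p.SetOp(header.ARPReply)
	// Hardware and protocol addresses would be copied in here.

	pkt := tcpip.PacketBuffer{Header: hdr} // header-only: Data stays empty
	fmt.Println("used bytes:", pkt.Header.UsedLength())
}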
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 88b57ec03..8e6048a21 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -102,19 +102,21 @@ func TestDirectRequest(t *testing.T) {
inject := func(addr tcpip.Address) {
copy(h.ProtocolAddressTarget(), addr)
- c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView())
+ c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{
+ Data: v.ToVectorisedView(),
+ })
}
for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
inject(address)
- pkt := <-c.linkEP.C
- if pkt.Proto != arp.ProtocolNumber {
- t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto)
+ pi := <-c.linkEP.C
+ if pi.Proto != arp.ProtocolNumber {
+ t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
}
- rep := header.ARP(pkt.Header)
+ rep := header.ARP(pi.Pkt.Header.View())
if !rep.IsValid() {
- t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
+ t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength())
}
if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 2cad0a0b6..acf1e022c 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -25,7 +25,7 @@ go_library(
"reassembler_list.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation",
- visibility = ["//:sandbox"],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/log",
"//pkg/tcpip",
@@ -44,11 +44,3 @@ go_test(
embed = [":fragmentation"],
deps = ["//pkg/tcpip/buffer"],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "reassembler_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index f644a8b08..4144a7837 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -96,16 +96,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
- t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress)
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) {
+ t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
t.dataCalls++
}
// DeliverTransportControlPacket is called by network endpoints after parsing
// incoming control (ICMP) packets. This is used by the test object to verify
// that the results of the parsing are expected.
-func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
- t.checkValues(trans, vv, remote, local)
+func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+ t.checkValues(trans, pkt.Data, remote, local)
if typ != t.typ {
t.t.Errorf("typ = %v, want %v", typ, t.typ)
}
@@ -150,27 +150,36 @@ func (*testObject) Wait() {}
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
if t.v4 {
- h := header.IPv4(hdr.View())
+ h := header.IPv4(pkt.Header.View())
prot = tcpip.TransportProtocolNumber(h.Protocol())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
} else {
- h := header.IPv6(hdr.View())
+ h := header.IPv6(pkt.Header.View())
prot = tcpip.TransportProtocolNumber(h.NextHeader())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
}
- t.checkValues(prot, payload, srcAddr, dstAddr)
+ t.checkValues(prot, pkt.Data, srcAddr, dstAddr)
return nil
}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
@@ -230,7 +239,10 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil {
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload.ToVectorisedView(),
+ }); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
@@ -270,7 +282,9 @@ func TestIPv4Receive(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, view.ToVectorisedView())
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: view.ToVectorisedView(),
+ })
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -358,7 +372,9 @@ func TestIPv4ReceiveControl(t *testing.T) {
o.extra = c.expectedExtra
vv := view[:len(view)-c.trunc].ToVectorisedView()
- ep.HandlePacket(&r, vv)
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: vv,
+ })
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
@@ -421,13 +437,17 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Send first segment.
- ep.HandlePacket(&r, frag1.ToVectorisedView())
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: frag1.ToVectorisedView(),
+ })
if o.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
}
// Send second segment.
- ep.HandlePacket(&r, frag2.ToVectorisedView())
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: frag2.ToVectorisedView(),
+ })
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -460,7 +480,10 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil {
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload.ToVectorisedView(),
+ }); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
@@ -500,7 +523,9 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, view.ToVectorisedView())
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: view.ToVectorisedView(),
+ })
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -510,6 +535,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
newUint16 := func(v uint16) *uint16 { return &v }
const mtu = 0xffff
+ const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa"
cases := []struct {
name string
expectedCount int
@@ -561,7 +587,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: 20,
- SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
+ SrcAddr: outerSrcAddr,
DstAddr: localIpv6Addr,
})
@@ -608,8 +634,12 @@ func TestIPv6ReceiveControl(t *testing.T) {
o.typ = c.expectedTyp
o.extra = c.expectedExtra
- vv := view[:len(view)-c.trunc].ToVectorisedView()
- ep.HandlePacket(&r, vv)
+ // Set ICMPv6 checksum.
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
+
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: view[:len(view)-c.trunc].ToVectorisedView(),
+ })
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
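
The new SetChecksum call above is needed because the IPv6 ICMP handler (later in this change) validates checksums before processing. The ICMPv6 checksum covers an IPv6 pseudo-header (source and destination addresses, upper-layer length, next-header value) in addition to the message itself. A self-contained sketch of that calculation:

package main

import "fmt"

// sum16 accumulates b into a 16-bit one's-complement running sum.
func sum16(b []byte, sum uint32) uint32 {
	for i := 0; i+1 < len(b); i += 2 {
		sum += uint32(b[i])<<8 | uint32(b[i+1])
	}
	if len(b)%2 == 1 {
		sum += uint32(b[len(b)-1]) << 8 // pad the odd trailing byte
	}
	return sum
}

// icmpv6Checksum mirrors what header.ICMPv6Checksum computes: the
// pseudo-header words followed by the ICMPv6 message.
func icmpv6Checksum(msg, src, dst []byte) uint16 {
	sum := sum16(src, 0)
	sum = sum16(dst, sum)
	sum += uint32(len(msg)) // upper-layer length (fits in 16 bits here)
	sum += 58               // next header: ICMPv6
	sum = sum16(msg, sum)
	for sum > 0xffff {
		sum = (sum >> 16) + (sum & 0xffff) // fold the carries
	}
	return ^uint16(sum)
}

func main() {
	src, dst := make([]byte, 16), make([]byte, 16)
	msg := []byte{135, 0, 0, 0} // neighbor solicitation, checksum zeroed
	fmt.Printf("checksum=%#04x\n", icmpv6Checksum(msg, src, dst))
}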
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 58e537aad..aeddfcdd4 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -10,9 +10,7 @@ go_library(
"ipv4.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 50b363dc4..32bf39e43 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,6 +15,7 @@
package ipv4
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -24,8 +25,8 @@ import (
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
- h := header.IPv4(vv.First())
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+ h := header.IPv4(pkt.Data.First())
// We don't use IsValid() here because ICMP only requires that the IP
// header plus 8 bytes of the transport header be included. So it's
@@ -39,7 +40,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
}
hlen := int(h.HeaderLength())
- if vv.Size() < hlen || h.FragmentOffset() != 0 {
+ if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 {
// We won't be able to handle this if it doesn't contain the
// full IPv4 header, or if it's a fragment not at offset 0
// (because it won't have the transport header).
@@ -47,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
}
// Skip the ip header, then deliver control message.
- vv.TrimFront(hlen)
+ pkt.Data.TrimFront(hlen)
p := h.TransportProtocol()
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) {
stats := r.Stats()
received := stats.ICMP.V4PacketsReceived
- v := vv.First()
+ v := pkt.Data.First()
if len(v) < header.ICMPv4MinimumSize {
received.Invalid.Increment()
return
@@ -73,20 +74,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
// checksum. We'll have to reset this before we hand the packet
// off.
h.SetChecksum(0)
- gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */)
+ gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
if gotChecksum != wantChecksum {
// It's possible that a raw socket expects to receive this.
h.SetChecksum(wantChecksum)
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
received.Invalid.Increment()
return
}
// It's possible that a raw socket expects to receive this.
h.SetChecksum(wantChecksum)
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, tcpip.PacketBuffer{
+ Data: pkt.Data.Clone(nil),
+ NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...),
+ })
- vv := vv.Clone(nil)
+ vv := pkt.Data.Clone(nil)
vv.TrimFront(header.ICMPv4MinimumSize)
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize)
pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
@@ -95,7 +99,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
pkt.SetChecksum(0)
pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: vv,
+ TransportHeader: buffer.View(pkt),
+ }); err != nil {
sent.Dropped.Increment()
return
}
@@ -104,19 +112,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv4EchoReply:
received.EchoReply.Increment()
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
received.DstUnreachable.Increment()
- vv.TrimFront(header.ICMPv4MinimumSize)
+ pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
case header.ICMPv4PortUnreachable:
- e.handleControl(stack.ControlPortUnreachable, 0, vv)
+ e.handleControl(stack.ControlPortUnreachable, 0, pkt)
case header.ICMPv4FragmentationNeeded:
mtu := uint32(h.MTU())
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
}
case header.ICMPv4SrcQuench:
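
handleControl relies on the ICMP error carrying the embedded IPv4 header of the original packet plus at least 8 bytes of its transport header: it parses the embedded header, trims it from pkt.Data, and delivers only the transport portion. A self-contained sketch of that parse-and-trim step:

package main

import "fmt"

func main() {
	// An ICMP error payload: the original 20-byte IPv4 header (IHL=5)
	// followed by 8 bytes of the original transport header.
	payload := make([]byte, 28)
	payload[0] = 0x45 // version 4, IHL 5 -> 20-byte header

	hlen := int(payload[0]&0x0f) * 4
	if len(payload) < hlen {
		fmt.Println("truncated: cannot recover the transport header")
		return
	}
	transport := payload[hlen:] // what DeliverTransportControlPacket gets
	fmt.Printf("ip header=%dB transport=%dB\n", hlen, len(transport))
}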
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 5cd895ff0..e645cf62c 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -47,7 +47,7 @@ const (
)
type endpoint struct {
- nicid tcpip.NICID
+ nicID tcpip.NICID
id stack.NetworkEndpointID
prefixLen int
linkEP stack.LinkEndpoint
@@ -57,9 +57,9 @@ type endpoint struct {
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
e := &endpoint{
- nicid: nicid,
+ nicID: nicID,
id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
prefixLen: addrWithPrefix.PrefixLen,
linkEP: linkEP,
@@ -89,7 +89,7 @@ func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
// NICID returns the ID of the NIC this endpoint belongs to.
func (e *endpoint) NICID() tcpip.NICID {
- return e.nicid
+ return e.nicID
}
// ID returns the ipv4 endpoint ID.
@@ -117,13 +117,14 @@ func (e *endpoint) GSOMaxSize() uint32 {
}
// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
-// write. It assumes that the IP header is entirely in hdr but does not assume
-// that only the IP header is in hdr. It assumes that the input packet's stated
-// length matches the length of the hdr+payload. mtu includes the IP header and
-// options. This does not support the DontFragment IP flag.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error {
+// write. It assumes that the IP header is entirely in pkt.Header but does not
+// assume that only the IP header is in pkt.Header. It assumes that the input
+// packet's stated length matches the length of the header+payload. mtu
+// includes the IP header and options. This does not support the DontFragment
+// IP flag.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt tcpip.PacketBuffer) *tcpip.Error {
// This packet is too big, it needs to be fragmented.
- ip := header.IPv4(hdr.View())
+ ip := header.IPv4(pkt.Header.View())
flags := ip.Flags()
// Update mtu to take into account the header, which will exist in all
@@ -137,71 +138,85 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff
outerMTU := innerMTU + int(ip.HeaderLength())
offset := ip.FragmentOffset()
- originalAvailableLength := hdr.AvailableLength()
+ originalAvailableLength := pkt.Header.AvailableLength()
for i := 0; i < n; i++ {
// Where possible, the first fragment that is sent has the same
- // hdr.UsedLength() as the input packet. The link-layer endpoint may depends
- // on this for looking at, eg, L4 headers.
+ // pkt.Header.UsedLength() as the input packet. The link-layer
+ // endpoint may depend on this for looking at, eg, L4 headers.
h := ip
if i > 0 {
- hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
- h = header.IPv4(hdr.Prepend(int(ip.HeaderLength())))
+ pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
+ h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength())))
copy(h, ip[:ip.HeaderLength()])
}
if i != n-1 {
h.SetTotalLength(uint16(outerMTU))
h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
} else {
- h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size()))
+ h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size()))
h.SetFlagsFragmentOffset(flags, offset)
}
h.SetChecksum(0)
h.SetChecksum(^h.CalculateChecksum())
offset += uint16(innerMTU)
if i > 0 {
- newPayload := payload.Clone([]buffer.View{})
+ newPayload := pkt.Data.Clone(nil)
newPayload.CapLength(innerMTU)
- if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
+ Header: pkt.Header,
+ Data: newPayload,
+ NetworkHeader: buffer.View(h),
+ }); err != nil {
return err
}
r.Stats().IP.PacketsSent.Increment()
- payload.TrimFront(newPayload.Size())
+ pkt.Data.TrimFront(newPayload.Size())
continue
}
- // Special handling for the first fragment because it comes from the hdr.
- if outerMTU >= hdr.UsedLength() {
- // This fragment can fit all of hdr and possibly some of payload, too.
- newPayload := payload.Clone([]buffer.View{})
- newPayloadLength := outerMTU - hdr.UsedLength()
+ // Special handling for the first fragment because it comes
+ // from the header.
+ if outerMTU >= pkt.Header.UsedLength() {
+ // This fragment can fit all of pkt.Header and possibly
+ // some of pkt.Data, too.
+ newPayload := pkt.Data.Clone(nil)
+ newPayloadLength := outerMTU - pkt.Header.UsedLength()
newPayload.CapLength(newPayloadLength)
- if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
+ Header: pkt.Header,
+ Data: newPayload,
+ NetworkHeader: buffer.View(h),
+ }); err != nil {
return err
}
r.Stats().IP.PacketsSent.Increment()
- payload.TrimFront(newPayloadLength)
+ pkt.Data.TrimFront(newPayloadLength)
} else {
- // The fragment is too small to fit all of hdr.
- startOfHdr := hdr
- startOfHdr.TrimBack(hdr.UsedLength() - outerMTU)
+ // The fragment is too small to fit all of pkt.Header.
+ startOfHdr := pkt.Header
+ startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU)
emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
- if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil {
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
+ Header: startOfHdr,
+ Data: emptyVV,
+ NetworkHeader: buffer.View(h),
+ }); err != nil {
return err
}
r.Stats().IP.PacketsSent.Increment()
- // Add the unused bytes of hdr into the payload that remains to be sent.
- restOfHdr := hdr.View()[outerMTU:]
+ // Add the unused bytes of pkt.Header into the pkt.Data
+ // that remains to be sent.
+ restOfHdr := pkt.Header.View()[outerMTU:]
tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)})
- tmp.Append(payload)
- payload = tmp
+ tmp.Append(pkt.Data)
+ pkt.Data = tmp
}
}
return nil
}
-// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- length := uint16(hdr.UsedLength() + payload.Size())
+ length := uint16(hdr.UsedLength() + payloadSize)
id := uint32(0)
if length > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
@@ -219,41 +234,71 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
+ return ip
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+ ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
+ pkt.NetworkHeader = buffer.View(ip)
if loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(payload.Views()))
- views[0] = hdr.View()
- views = append(views, payload.Views()...)
- vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+ // The inbound path expects the network header to still be in
+ // the PacketBuffer's Data field.
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, vv)
+
+ e.HandlePacket(&loopedR, tcpip.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
+ })
+
loopedR.Release()
}
if loop&stack.PacketOut == 0 {
return nil
}
- if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
- return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU()))
+ if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
+ return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
}
- if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil {
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
return err
}
r.Stats().IP.PacketsSent.Increment()
return nil
}
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+ if loop&stack.PacketLoop != 0 {
+ panic("multiple packets in local loop")
+ }
+ if loop&stack.PacketOut == 0 {
+ return len(pkts), nil
+ }
+
+ for i := range pkts {
+ ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params)
+ pkts[i].NetworkHeader = buffer.View(ip)
+ }
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+}
+
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
- ip := header.IPv4(payload.First())
- if !ip.IsValid(payload.Size()) {
+ ip := header.IPv4(pkt.Data.First())
+ if !ip.IsValid(pkt.Data.Size()) {
return tcpip.ErrInvalidOptionValue
}
// Always set the total length.
- ip.SetTotalLength(uint16(payload.Size()))
+ ip.SetTotalLength(uint16(pkt.Data.Size()))
// Set the source address when zero.
if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
@@ -267,7 +312,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect
// Set the packet ID when zero.
if ip.ID() == 0 {
id := uint32(0)
- if payload.Size() > header.IPv4MaximumHeaderSize+8 {
+ if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
@@ -280,35 +325,39 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect
ip.SetChecksum(^ip.CalculateChecksum())
if loop&stack.PacketLoop != 0 {
- e.HandlePacket(r, payload)
+ e.HandlePacket(r, pkt.Clone())
}
if loop&stack.PacketOut == 0 {
return nil
}
- hdr := buffer.NewPrependableFromView(payload.ToView())
r.Stats().IP.PacketsSent.Increment()
- return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+
+ ip = ip[:ip.HeaderLength()]
+ pkt.Header = buffer.NewPrependableFromView(buffer.View(ip))
+ pkt.Data.TrimFront(int(ip.HeaderLength()))
+ return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- headerView := vv.First()
+func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+ headerView := pkt.Data.First()
h := header.IPv4(headerView)
- if !h.IsValid(vv.Size()) {
+ if !h.IsValid(pkt.Data.Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
+ pkt.NetworkHeader = headerView[:h.HeaderLength()]
hlen := int(h.HeaderLength())
tlen := int(h.TotalLength())
- vv.TrimFront(hlen)
- vv.CapLength(tlen - hlen)
+ pkt.Data.TrimFront(hlen)
+ pkt.Data.CapLength(tlen - hlen)
more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
if more || h.FragmentOffset() != 0 {
- if vv.Size() == 0 {
+ if pkt.Data.Size() == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -316,10 +365,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
return
}
// The packet is a fragment, let's try to reassemble it.
- last := h.FragmentOffset() + uint16(vv.Size()) - 1
+ last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1
// Drop the packet if the fragmentOffset is incorrect, i.e. the
- // combination of fragmentOffset and vv.size() causes a wrap
- // around resulting in last being less than the offset.
+ // combination of fragmentOffset and pkt.Data.Size() causes a
+ // wrap around resulting in last being less than the offset.
if last < h.FragmentOffset() {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
@@ -327,7 +376,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
}
var ready bool
var err error
- vv, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
+ pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data)
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
@@ -340,11 +389,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
headerView.CapLength(hlen)
- e.handleICMP(r, headerView, vv)
+ e.handleICMP(r, pkt)
return
}
r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
+ e.dispatcher.DeliverTransportPacket(r, p, pkt)
}
// Close cleans up resources associated with the endpoint.
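
A worked sketch of the arithmetic behind writePacketFragments: every fragment repeats the IP header, the payload splits into 8-byte-aligned innerMTU chunks, and the fragment offset advances by innerMTU per iteration (the values below are illustrative):

package main

import "fmt"

func main() {
	const (
		mtu        = 1500
		ipHdrLen   = 20 // IPv4 header without options
		payloadLen = 4000
	)
	innerMTU := (mtu - ipHdrLen) &^ 7 // fragment payloads align to 8 bytes
	n := (payloadLen + innerMTU - 1) / innerMTU
	fmt.Printf("innerMTU=%d fragments=%d\n", innerMTU, n)

	for i, offset := 0, 0; i < n; i++ {
		size := innerMTU
		if remaining := payloadLen - offset; remaining < size {
			size = remaining // the last fragment carries the remainder
		}
		fmt.Printf("fragment %d: offset=%d payload=%dB wire=%dB\n", i, offset, size, size+ipHdrLen)
		offset += size
	}
}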
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 85ab0e3bc..e900f1b45 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -47,10 +47,6 @@ func TestExcludeBroadcast(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
NIC: 1,
@@ -117,12 +113,12 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.
// compareFragments compares the contents of all the packets against the contents
// of the source packet.
-func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) {
+func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) {
t.Helper()
// Make a complete array of the sourcePacketInfo packet.
source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
source = append(source, sourcePacketInfo.Header.View()...)
- source = append(source, sourcePacketInfo.Payload.ToView()...)
+ source = append(source, sourcePacketInfo.Data.ToView()...)
// Make a copy of the IP header, which will be modified in some fields to make
// an expected header.
@@ -136,7 +132,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe
for i, packet := range packets {
// Confirm that the packet is valid.
allBytes := packet.Header.View().ToVectorisedView()
- allBytes.Append(packet.Payload)
+ allBytes.Append(packet.Data)
ip := header.IPv4(allBytes.ToView())
if !ip.IsValid(len(ip)) {
t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
@@ -177,7 +173,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe
type errorChannel struct {
*channel.Endpoint
- Ch chan packetInfo
+ Ch chan tcpip.PacketBuffer
packetCollectorErrors []*tcpip.Error
}
@@ -187,17 +183,11 @@ type errorChannel struct {
func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
return &errorChannel{
Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan packetInfo, size),
+ Ch: make(chan tcpip.PacketBuffer, size),
packetCollectorErrors: packetCollectorErrors,
}
}
-// packetInfo holds all the information about an outbound packet.
-type packetInfo struct {
- Header buffer.Prependable
- Payload buffer.VectorisedView
-}
-
// Drain removes all outbound packets from the channel and counts them.
func (e *errorChannel) Drain() int {
c := 0
@@ -212,14 +202,9 @@ func (e *errorChannel) Drain() int {
}
// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
- p := packetInfo{
- Header: hdr,
- Payload: payload,
- }
-
+func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
select {
- case e.Ch <- p:
+ case e.Ch <- pkt:
default:
}
@@ -296,18 +281,21 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
- source := packetInfo{
+ source := tcpip.PacketBuffer{
Header: hdr,
// Save the source payload because WritePacket will modify it.
- Payload: payload.Clone([]buffer.View{}),
+ Data: payload.Clone(nil),
}
c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS})
+ err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
if err != nil {
t.Errorf("err got %v, want %v", err, nil)
}
- var results []packetInfo
+ var results []tcpip.PacketBuffer
L:
for {
select {
@@ -349,7 +337,10 @@ func TestFragmentationErrors(t *testing.T) {
t.Run(ft.description, func(t *testing.T) {
hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS})
+ err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
@@ -455,7 +446,7 @@ func TestInvalidFragments(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- const nicid tcpip.NICID = 42
+ const nicID tcpip.NICID = 42
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{
ipv4.NewProtocol(),
@@ -465,10 +456,12 @@ func TestInvalidFragments(t *testing.T) {
var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30})
var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31})
ep := channel.New(10, 1500, linkAddr)
- s.CreateNIC(nicid, sniffer.New(ep))
+ s.CreateNIC(nicID, sniffer.New(ep))
for _, pkt := range tc.packets {
- ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}))
+ ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
+ })
}
if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want {
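
The test harness above drives the stack purely through a channel endpoint. A hedged end-to-end sketch of that inject-and-check flow, assuming the gvisor.dev/gvisor module at the revision of this change:

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func main() {
	s := stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
	})
	ep := channel.New(10, 1500, tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"))
	if err := s.CreateNIC(1, ep); err != nil {
		panic(err)
	}

	// Inject a deliberately malformed (all-zero) IPv4 packet, as the
	// invalid-fragment cases above do.
	raw := buffer.NewView(header.IPv4MinimumSize)
	ep.InjectInbound(header.IPv4ProtocolNumber, tcpip.PacketBuffer{
		Data: raw.ToVectorisedView(),
	})

	fmt.Println("malformed packets:", s.Stats().IP.MalformedPacketsReceived.Value())
}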
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index f06622a8b..e4e273460 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -10,9 +10,7 @@ go_library(
"ipv6.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 7b638e9d0..1c3410618 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -21,21 +21,12 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-const (
- // ndpHopLimit is the expected IP hop limit value of 255 for received
- // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
- // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
- // drop the NDP packet. All outgoing NDP packets must use this value for
- // its IP hop limit field.
- ndpHopLimit = 255
-)
-
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
- h := header.IPv6(vv.First())
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+ h := header.IPv6(pkt.Data.First())
// We don't use IsValid() here because ICMP only requires that up to
// 1280 bytes of the original packet be included. So it's likely that it
@@ -49,10 +40,10 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
// Skip the IP header, then handle the fragmentation header if there
// is one.
- vv.TrimFront(header.IPv6MinimumSize)
+ pkt.Data.TrimFront(header.IPv6MinimumSize)
p := h.TransportProtocol()
if p == header.IPv6FragmentHeader {
- f := header.IPv6Fragment(vv.First())
+ f := header.IPv6Fragment(pkt.Data.First())
if !f.IsValid() || f.FragmentOffset() != 0 {
// We can't handle fragments that aren't at offset 0
// because they don't have the transport headers.
@@ -61,19 +52,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
// Skip fragmentation header and find out the actual protocol
// number.
- vv.TrimFront(header.IPv6FragmentHeaderSize)
+ pkt.Data.TrimFront(header.IPv6FragmentHeaderSize)
p = f.TransportProtocol()
}
// Deliver the control packet to the transport endpoint.
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.PacketBuffer) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
- v := vv.First()
+ v := pkt.Data.First()
if len(v) < header.ICMPv6MinimumSize {
received.Invalid.Increment()
return
@@ -81,16 +72,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
h := header.ICMPv6(v)
iph := header.IPv6(netHeader)
+ // Validate ICMPv6 checksum before processing the packet.
+ //
+ // Only the first view in vv is accounted for by h. To account for the
+ // rest of vv, a shallow copy is made and the first view is removed.
+ // This copy is used as extra payload during the checksum calculation.
+ payload := pkt.Data
+ payload.RemoveFirst()
+ if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want {
+ received.Invalid.Increment()
+ return
+ }
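
The shallow copy above is cheap and safe because a VectorisedView is a small descriptor over shared views: RemoveFirst on the copy leaves pkt.Data intact. Isolated as a sketch (src and dst stand in for the IPv6 source and destination addresses):

    // Sketch: checksumming an ICMPv6 message that spans several views.
    h := header.ICMPv6(pkt.Data.First()) // h covers only the first view
    payload := pkt.Data                  // shallow copy; views are shared
    payload.RemoveFirst()                // drop the view h already covers
    if h.Checksum() != header.ICMPv6Checksum(h, src, dst, payload) {
        // Bad checksum: count it as invalid and drop, as above.
    }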
+
// As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
// 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
- // in the IPv6 header is not set to 255.
+ // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not
+ // set to 0.
switch h.Type() {
case header.ICMPv6NeighborSolicit,
header.ICMPv6NeighborAdvert,
header.ICMPv6RouterSolicit,
header.ICMPv6RouterAdvert,
header.ICMPv6RedirectMsg:
- if iph.HopLimit() != ndpHopLimit {
+ if iph.HopLimit() != header.NDPHopLimit {
+ received.Invalid.Increment()
+ return
+ }
+
+ if h.Code() != 0 {
received.Invalid.Increment()
return
}
@@ -104,9 +113,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
+ pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
mtu := h.MTU()
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
case header.ICMPv6DstUnreachable:
received.DstUnreachable.Increment()
@@ -114,10 +123,10 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
+ pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
switch h.Code() {
case header.ICMPv6PortUnreachable:
- e.handleControl(stack.ControlPortUnreachable, 0, vv)
+ e.handleControl(stack.ControlPortUnreachable, 0, pkt)
}
case header.ICMPv6NeighborSolicit:
@@ -171,7 +180,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
// rxNICID so the packet is processed as defined in RFC 4861,
// as per RFC 4862 section 5.4.3.
- if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
+ if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
// We don't have a useful answer; the best we can do is ignore the request.
return
}
@@ -180,9 +189,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]),
}
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()))
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ packet.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(packet.NDPPayload())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(targetAddr)
@@ -200,7 +209,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
r := r.Clone()
defer r.Release()
r.LocalAddress = targetAddr
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
// TODO(tamird/ghanan): there exists an explicit NDP option that is
// used to update the neighbor table with link addresses for a
@@ -209,7 +218,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
//
// Furthermore, the entirety of NDP handling here seems to be
// contradicted by RFC 4861.
- e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+ e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress)
// RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
//
@@ -217,7 +226,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
//
// The IP Hop Limit field has a value of 255, i.e., the packet
// could not possibly have been forwarded by a router.
- if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ndpHopLimit, TOS: stack.DefaultTOS}); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ }); err != nil {
sent.Dropped.Increment()
return
}
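
For reference, the reshaped send path: header parameters now come before the packet, and locally built headers ride in the PacketBuffer's Header field. A sketch of a minimal neighbor advert send under that signature:

    // Sketch: sending with the post-change WritePacket signature.
    hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertSize)
    na := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
    na.SetType(header.ICMPv6NeighborAdvert)
    na.SetChecksum(header.ICMPv6Checksum(na, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
    if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
        Protocol: header.ICMPv6ProtocolNumber,
        TTL:      header.NDPHopLimit,
        TOS:      stack.DefaultTOS,
    }, tcpip.PacketBuffer{Header: hdr}); err != nil {
        sent.Dropped.Increment()
    }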
@@ -255,19 +266,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
return
}
- // At this point we know that the targetAddress is not tentaive
+ // At this point we know that the targetAddress is not tentative
// on rxNICID. However, targetAddr may still be assigned to
// rxNICID but not tentative (it could be permanent). Such a
// scenario is beyond the scope of RFC 4862. As such, we simply
// ignore such a scenario for now and proceed as normal.
//
- // TODO(b/140896005): Handle the scenario described above
- // (inform the netstack integration that a duplicate address was
- // was detected)
+ // TODO(b/143147598): Handle the scenario described above. Also
+ // inform the netstack integration that a duplicate address was
+ // detected outside of DAD.
- e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
+ e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, r.RemoteLinkAddress)
if targetAddr != r.RemoteAddress {
- e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+ e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress)
}
case header.ICMPv6EchoRequest:
@@ -276,13 +287,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- vv.TrimFront(header.ICMPv6EchoMinimumSize)
+ pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize)
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
- copy(pkt, h)
- pkt.SetType(header.ICMPv6EchoReply)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
- if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
+ packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ copy(packet, h)
+ packet.SetType(header.ICMPv6EchoReply)
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: pkt.Data,
+ }); err != nil {
sent.Dropped.Increment()
return
}
@@ -294,7 +308,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt)
case header.ICMPv6TimeExceeded:
received.TimeExceeded.Increment()
@@ -306,8 +320,51 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.RouterSolicit.Increment()
case header.ICMPv6RouterAdvert:
+ routerAddr := iph.SourceAddress()
+
+ //
+ // Validate the RA as per RFC 4861 section 6.1.2.
+ //
+
+ // Is the IP Source Address a link-local address?
+ if !header.IsV6LinkLocalAddress(routerAddr) {
+ // ...No, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ p := h.NDPPayload()
+
+ // Is the NDP payload of sufficient size to hold a Router
+ // Advertisement?
+ if len(p) < header.NDPRAMinimumSize {
+ // ...No, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ ra := header.NDPRouterAdvert(p)
+ opts := ra.Options()
+
+ // Are options valid as per the wire format?
+ if _, err := opts.Iter(true); err != nil {
+ // ...No, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ //
+ // At this point, we have a valid Router Advertisement, as far
+ // as RFC 4861 section 6.1.2 is concerned.
+ //
+
received.RouterAdvert.Increment()
+ // Tell the NIC to handle the RA.
+ stack := r.Stack()
+ rxNICID := r.NICID()
+ stack.HandleNDPRA(rxNICID, routerAddr, ra)
+
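
Condensed, the gate above is three ordered checks from RFC 4861 section 6.1.2, each dropping the packet silently on failure:

    // Sketch: the RA validation gate, in order.
    if !header.IsV6LinkLocalAddress(iph.SourceAddress()) {
        return // source must be link-local
    }
    p := h.NDPPayload()
    if len(p) < header.NDPRAMinimumSize {
        return // too short to hold a Router Advertisement
    }
    if _, err := header.NDPRouterAdvert(p).Options().Iter(true); err != nil {
        return // options fail wire-format validation
    }
    // Only now is the RA handed to the stack via HandleNDPRA.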
case header.ICMPv6RedirectMsg:
received.RedirectMsg.Increment()
@@ -359,13 +416,15 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: ndpHopLimit,
+ HopLimit: header.NDPHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
// TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{
+ Header: hdr,
+ })
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index dd3c4d7c4..335f634d5 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -30,7 +30,7 @@ import (
)
const (
- linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
+ linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
)
@@ -55,7 +55,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
-func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error {
return nil
}
@@ -65,7 +65,7 @@ type stubDispatcher struct {
stack.TransportDispatcher
}
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, buffer.View, buffer.VectorisedView) {
+func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) {
}
type stubLinkAddressCache struct {
@@ -131,7 +131,7 @@ func TestICMPCounts(t *testing.T) {
{header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize},
{header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize},
{header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize},
- {header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize},
+ {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize},
{header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize},
{header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize},
{header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize},
@@ -143,11 +143,13 @@ func TestICMPCounts(t *testing.T) {
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: ndpHopLimit,
+ HopLimit: header.NDPHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, hdr.View().ToVectorisedView())
+ ep.HandlePacket(&r, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
}
for _, typ := range types {
@@ -274,20 +276,22 @@ type routeArgs struct {
func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
t.Helper()
- pkt := <-args.src.C
+ pi := <-args.src.C
{
- views := []buffer.View{pkt.Header, pkt.Payload}
- size := len(pkt.Header) + len(pkt.Payload)
+ views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()}
+ size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size()
vv := buffer.NewVectorisedView(size, views)
- args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), vv)
+ args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{
+ Data: vv,
+ })
}
- if pkt.Proto != ProtocolNumber {
- t.Errorf("unexpected protocol number %d", pkt.Proto)
+ if pi.Proto != ProtocolNumber {
+ t.Errorf("unexpected protocol number %d", pi.Proto)
return
}
- ipv6 := header.IPv6(pkt.Header)
+ ipv6 := header.IPv6(pi.Pkt.Header.View())
transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
if transProto != header.ICMPv6ProtocolNumber {
t.Errorf("unexpected transport protocol number %d", transProto)
@@ -359,3 +363,537 @@ func TestLinkResolution(t *testing.T) {
routeICMPv6Packet(t, args, nil)
}
}
+
+func TestICMPChecksumValidationSimple(t *testing.T) {
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ }{
+ {
+ "DstUnreachable",
+ header.ICMPv6DstUnreachable,
+ header.ICMPv6DstUnreachableMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.DstUnreachable
+ },
+ },
+ {
+ "PacketTooBig",
+ header.ICMPv6PacketTooBig,
+ header.ICMPv6PacketTooBigMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.PacketTooBig
+ },
+ },
+ {
+ "TimeExceeded",
+ header.ICMPv6TimeExceeded,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.TimeExceeded
+ },
+ },
+ {
+ "ParamProblem",
+ header.ICMPv6ParamProblem,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.ParamProblem
+ },
+ },
+ {
+ "EchoRequest",
+ header.ICMPv6EchoRequest,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoRequest
+ },
+ },
+ {
+ "EchoReply",
+ header.ICMPv6EchoReply,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoReply
+ },
+ },
+ {
+ "RouterSolicit",
+ header.ICMPv6RouterSolicit,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ },
+ {
+ "RouterAdvert",
+ header.ICMPv6RouterAdvert,
+ header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ "NeighborSolicit",
+ header.ICMPv6NeighborSolicit,
+ header.ICMPv6NeighborSolicitMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ "NeighborAdvert",
+ header.ICMPv6NeighborAdvert,
+ header.ICMPv6NeighborAdvertSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ "RedirectMsg",
+ header.ICMPv6RedirectMsg,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr0)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
+ pkt := header.ICMPv6(hdr.Prepend(size))
+ pkt.SetType(typ)
+ if checksum {
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ }
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(size),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(typ.typ, typ.size, false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(typ.typ, typ.size, true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
+
+func TestICMPChecksumValidationWithPayload(t *testing.T) {
+ const simpleBodySize = 64
+ simpleBody := func(view buffer.View) {
+ for i := 0; i < simpleBodySize; i++ {
+ view[i] = uint8(i)
+ }
+ }
+
+ const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
+ errorICMPBody := func(view buffer.View) {
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: simpleBodySize,
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
+ })
+ simpleBody(view[header.IPv6MinimumSize:])
+ }
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ payloadSize int
+ payload func(buffer.View)
+ }{
+ {
+ "DstUnreachable",
+ header.ICMPv6DstUnreachable,
+ header.ICMPv6DstUnreachableMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.DstUnreachable
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "PacketTooBig",
+ header.ICMPv6PacketTooBig,
+ header.ICMPv6PacketTooBigMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.PacketTooBig
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "TimeExceeded",
+ header.ICMPv6TimeExceeded,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.TimeExceeded
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "ParamProblem",
+ header.ICMPv6ParamProblem,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.ParamProblem
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "EchoRequest",
+ header.ICMPv6EchoRequest,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoRequest
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ {
+ "EchoReply",
+ header.ICMPv6EchoReply,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoReply
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr0)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
+ icmpSize := size + payloadSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(typ)
+ payloadFn(pkt.Payload())
+
+ if checksum {
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ }
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(icmpSize),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
+
+func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
+ const simpleBodySize = 64
+ simpleBody := func(view buffer.View) {
+ for i := 0; i < simpleBodySize; i++ {
+ view[i] = uint8(i)
+ }
+ }
+
+ const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize
+ errorICMPBody := func(view buffer.View) {
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: simpleBodySize,
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: lladdr0,
+ DstAddr: lladdr1,
+ })
+ simpleBody(view[header.IPv6MinimumSize:])
+ }
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ payloadSize int
+ payload func(buffer.View)
+ }{
+ {
+ "DstUnreachable",
+ header.ICMPv6DstUnreachable,
+ header.ICMPv6DstUnreachableMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.DstUnreachable
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "PacketTooBig",
+ header.ICMPv6PacketTooBig,
+ header.ICMPv6PacketTooBigMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.PacketTooBig
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "TimeExceeded",
+ header.ICMPv6TimeExceeded,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.TimeExceeded
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "ParamProblem",
+ header.ICMPv6ParamProblem,
+ header.ICMPv6MinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.ParamProblem
+ },
+ errorICMPBodySize,
+ errorICMPBody,
+ },
+ {
+ "EchoRequest",
+ header.ICMPv6EchoRequest,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoRequest
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ {
+ "EchoReply",
+ header.ICMPv6EchoReply,
+ header.ICMPv6EchoMinimumSize,
+ func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.EchoReply
+ },
+ simpleBodySize,
+ simpleBody,
+ },
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr0)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
+ pkt := header.ICMPv6(hdr.Prepend(size))
+ pkt.SetType(typ)
+
+ payload := buffer.NewView(payloadSize)
+ payloadFn(payload)
+
+ if checksum {
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView()))
+ }
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(size + payloadSize),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
+ })
+ }
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ // Initial stat counts should be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Without setting checksum, the incoming packet should
+ // be invalid.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false)
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ // Rx count of type typ.typ should not have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // When checksum is set, it should be received.
+ handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true)
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index cd1e34085..e13f1fabf 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -43,7 +43,7 @@ const (
)
type endpoint struct {
- nicid tcpip.NICID
+ nicID tcpip.NICID
id stack.NetworkEndpointID
prefixLen int
linkEP stack.LinkEndpoint
@@ -65,7 +65,7 @@ func (e *endpoint) MTU() uint32 {
// NICID returns the ID of the NIC this endpoint belongs to.
func (e *endpoint) NICID() tcpip.NICID {
- return e.nicid
+ return e.nicID
}
// ID returns the ipv6 endpoint ID.
@@ -97,9 +97,8 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
-// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
- length := uint16(hdr.UsedLength() + payload.Size())
+func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 {
+ length := uint16(hdr.UsedLength() + payloadSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
@@ -109,14 +108,26 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
+ return ip
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+ ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
+ pkt.NetworkHeader = buffer.View(ip)
if loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(payload.Views()))
- views[0] = hdr.View()
- views = append(views, payload.Views()...)
- vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+ // The inbound path expects the network header to still be in
+ // the PacketBuffer's Data field.
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, vv)
+
+ e.HandlePacket(&loopedR, tcpip.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
+ })
+
loopedR.Release()
}
if loop&stack.PacketOut == 0 {
@@ -124,36 +135,58 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
}
r.Stats().IP.PacketsSent.Increment()
- return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber)
+ return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+ if loop&stack.PacketLoop != 0 {
+ panic("not implemented")
+ }
+ if loop&stack.PacketOut == 0 {
+ return len(pkts), nil
+ }
+
+ for i := range pkts {
+ hdr := &pkts[i].Header
+ size := pkts[i].DataSize
+ ip := e.addIPHeader(r, hdr, size, params)
+ pkts[i].NetworkHeader = buffer.View(ip)
+ }
+
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. It is not yet
// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
- // TODO(b/119580726): Support IPv6 header-included packets.
+func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
+ // TODO(b/146666412): Support IPv6 header-included packets.
return tcpip.ErrNotSupported
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- headerView := vv.First()
+func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+ headerView := pkt.Data.First()
h := header.IPv6(headerView)
- if !h.IsValid(vv.Size()) {
+ if !h.IsValid(pkt.Data.Size()) {
return
}
- vv.TrimFront(header.IPv6MinimumSize)
- vv.CapLength(int(h.PayloadLength()))
+ pkt.NetworkHeader = headerView[:header.IPv6MinimumSize]
+ pkt.Data.TrimFront(header.IPv6MinimumSize)
+ pkt.Data.CapLength(int(h.PayloadLength()))
p := h.TransportProtocol()
if p == header.ICMPv6ProtocolNumber {
- e.handleICMP(r, headerView, vv)
+ e.handleICMP(r, headerView, pkt)
return
}
r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
+ e.dispatcher.DeliverTransportPacket(r, p, pkt)
}
// Close cleans up resources associated with the endpoint.
@@ -188,9 +221,9 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &endpoint{
- nicid: nicid,
+ nicID: nicID,
id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
prefixLen: addrWithPrefix.PrefixLen,
linkEP: linkEP,
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index deaa9b7f3..1cbfa7278 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -55,7 +55,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
stats := s.Stats().ICMP.V6PacketsReceived
@@ -111,7 +113,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
stat := s.Stats().UDP.PacketsReceived
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index e30791fe3..0dbce14a0 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
)
@@ -97,7 +98,9 @@ func TestHopLimitValidation(t *testing.T) {
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(r, hdr.View().ToVectorisedView())
+ ep.HandlePacket(r, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
}
types := []struct {
@@ -109,7 +112,7 @@ func TestHopLimitValidation(t *testing.T) {
{"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterSolicit
}},
- {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterAdvert
}},
{"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
@@ -150,7 +153,7 @@ func TestHopLimitValidation(t *testing.T) {
// Receive the NDP packet with an invalid hop limit
// value.
- handleIPv6Payload(hdr, ndpHopLimit-1, ep, &r)
+ handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r)
// Invalid count should have increased.
if got := invalid.Value(); got != 1 {
@@ -164,7 +167,7 @@ func TestHopLimitValidation(t *testing.T) {
}
// Receive the NDP packet with a valid hop limit value.
- handleIPv6Payload(hdr, ndpHopLimit, ep, &r)
+ handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r)
// Rx count of NDP packet of type typ.typ should have
// increased.
@@ -179,3 +182,191 @@ func TestHopLimitValidation(t *testing.T) {
})
}
}
+
+// TestRouterAdvertValidation tests that when the NIC is configured to handle
+// NDP Router Advertisement packets, it validates each advertisement
+// properly before handling it.
+func TestRouterAdvertValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ src tcpip.Address
+ hopLimit uint8
+ code uint8
+ ndpPayload []byte
+ expectedSuccess bool
+ }{
+ {
+ "OK",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ true,
+ },
+ {
+ "NonLinkLocalSourceAddr",
+ addr1,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "HopLimitNot255",
+ lladdr0,
+ 254,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "NonZeroCode",
+ lladdr0,
+ 255,
+ 1,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "NDPPayloadTooSmall",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "OKWithOptions",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ // RA payload
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+
+ // Option #1 (TargetLinkLayerAddress)
+ 2, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #2 (unrecognized)
+ 255, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #3 (PrefixInformation)
+ 3, 4, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ true,
+ },
+ {
+ "OptionWithZeroLength",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ // RA payload
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+
+ // Option #1 (TargetLinkLayerAddress)
+ // Invalid as it has 0 length.
+ 2, 0, 0, 0, 0, 0, 0, 0,
+
+ // Option #2 (unrecognized)
+ 255, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #3 (PrefixInformation)
+ 3, 4, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(test.code)
+ copy(pkt.NDPPayload(), test.ndpPayload)
+ payloadLength := hdr.UsedLength()
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ rxRA := stats.RouterAdvert
+
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := rxRA.Value(); got != 0 {
+ t.Fatalf("got rxRA = %d, want = 0", got)
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ if test.expectedSuccess {
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := rxRA.Value(); got != 1 {
+ t.Fatalf("got rxRA = %d, want = 1", got)
+ }
+
+ } else {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ if got := rxRA.Value(); got != 0 {
+ t.Fatalf("got rxRA = %d, want = 0", got)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go
new file mode 100644
index 000000000..ab24372e7
--- /dev/null
+++ b/pkg/tcpip/packet_buffer.go
@@ -0,0 +1,67 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import "gvisor.dev/gvisor/pkg/tcpip/buffer"
+
+// A PacketBuffer contains all the data of a network packet.
+//
+// As a PacketBuffer traverses up the stack, it may be necessary to pass it to
+// multiple endpoints. Clone() should be called in such cases so that
+// modifications to the Data field do not affect other copies.
+//
+// +stateify savable
+type PacketBuffer struct {
+ // Data holds the payload of the packet. For inbound packets, it also
+ // holds the headers, which are consumed as the packet moves up the
+ // stack. Headers are guaranteed not to be split across views.
+ //
+ // The bytes backing Data are immutable, but Data itself may be trimmed
+ // or otherwise modified.
+ Data buffer.VectorisedView
+
+ // DataOffset is used for GSO output. It is the offset into the Data
+ // field where the payload of this packet starts.
+ DataOffset int
+
+ // DataSize is used for GSO output. It is the size of this packet's
+ // payload.
+ DataSize int
+
+ // Header holds the headers of outbound packets. As a packet is passed
+ // down the stack, each layer adds to Header.
+ Header buffer.Prependable
+
+ // These fields are used by both inbound and outbound packets. They
+ // typically overlap with the Data and Header fields.
+ //
+ // The bytes backing these views are immutable. Each field may be nil
+ // if either it has not been set yet or no such header exists (e.g.
+ // packets sent via loopback may not have a link header).
+ //
+ // These fields may be Views into other slices (either Data or Header).
+ // SR doesn't support this, so deep copies are necessary in some cases.
+ LinkHeader buffer.View
+ NetworkHeader buffer.View
+ TransportHeader buffer.View
+}
+
+// Clone makes a copy of pk. It clones the Data field, which creates a new
+// VectorisedView but does not deep copy the underlying bytes.
+//
+// Clone also does not deep copy any of its other fields.
+func (pk PacketBuffer) Clone() PacketBuffer {
+ pk.Data = pk.Data.Clone(nil)
+ return pk
+}
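
A small sketch of the aliasing semantics documented above: Clone copies the VectorisedView bookkeeping, so trimming one copy does not disturb the other, while the underlying bytes stay shared:

    // Sketch: Clone copies the Data bookkeeping, not the bytes.
    orig := tcpip.PacketBuffer{
        Data: buffer.NewViewFromBytes([]byte("hdr|payload")).ToVectorisedView(),
    }
    pkt := orig.Clone()
    pkt.Data.TrimFront(4) // consume "hdr|" in the clone only
    _ = orig.Data.Size()  // still 11; orig is unaffected
    _ = pkt.Data.Size()   // 7
    // Writes through either copy remain visible through the other.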
diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go
new file mode 100644
index 000000000..ad3cc24fa
--- /dev/null
+++ b/pkg/tcpip/packet_buffer_state.go
@@ -0,0 +1,27 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import "gvisor.dev/gvisor/pkg/tcpip/buffer"
+
+// beforeSave is invoked by stateify.
+func (pk *PacketBuffer) beforeSave() {
+ // Non-Data fields may be slices of the Data field. This causes
+ // problems for SR, so during save we make each header independent.
+ pk.Header = pk.Header.DeepCopy()
+ pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...)
+ pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...)
+ pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...)
+}
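
The deep copies exist because NetworkHeader and friends are typically sub-slices of Data or Header, and the save/restore machinery cannot preserve that aliasing. The append-copy idiom used above, in isolation:

    // Sketch: breaking aliasing before save.
    backing := buffer.View("abcdef")
    hdr := backing[:3]                     // hdr aliases backing
    hdr = append(buffer.View(nil), hdr...) // now an independent copy
    backing[0] = 'X'
    _ = hdr[0] // still 'a': mutations no longer show through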
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 11efb4e44..e156b01f6 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -1,5 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -7,7 +7,7 @@ go_library(
name = "ports",
srcs = ["ports.go"],
importpath = "gvisor.dev/gvisor/pkg/tcpip/ports",
- visibility = ["//:sandbox"],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
],
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 30cea8996..6c5e19e8f 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -41,6 +41,30 @@ type portDescriptor struct {
port uint16
}
+// Flags represents the type of port reservation.
+//
+// +stateify savable
+type Flags struct {
+ // MostRecent represents UDP SO_REUSEADDR.
+ MostRecent bool
+
+ // LoadBalanced indicates SO_REUSEPORT.
+ //
+ // LoadBalanced takes precedence over MostRecent.
+ LoadBalanced bool
+}
+
+func (f Flags) bits() reuseFlag {
+ var rf reuseFlag
+ if f.MostRecent {
+ rf |= mostRecentFlag
+ }
+ if f.LoadBalanced {
+ rf |= loadBalancedFlag
+ }
+ return rf
+}
+
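
bits packs the two booleans into a two-bit index; the refs array introduced below is keyed by exactly these values. The four possibilities, as a sketch:

    // Sketch: the Flags -> bitmask mapping used as a refs index.
    _ = Flags{}.bits()                                     // 0
    _ = Flags{MostRecent: true}.bits()                     // mostRecentFlag (1)
    _ = Flags{LoadBalanced: true}.bits()                   // loadBalancedFlag (2)
    _ = Flags{MostRecent: true, LoadBalanced: true}.bits() // 3 (both bits)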
// PortManager manages allocating, reserving and releasing ports.
type PortManager struct {
mu sync.RWMutex
@@ -54,9 +78,59 @@ type PortManager struct {
hint uint32
}
+type reuseFlag int
+
+const (
+ mostRecentFlag reuseFlag = 1 << iota
+ loadBalancedFlag
+ nextFlag
+
+ flagMask = nextFlag - 1
+)
+
type portNode struct {
- reuse bool
- refs int
+ // refs stores the count for each possible flag combination.
+ refs [nextFlag]int
+}
+
+func (p portNode) totalRefs() int {
+ var total int
+ for _, r := range p.refs {
+ total += r
+ }
+ return total
+}
+
+// flagRefs returns the number of references with all specified flags.
+func (p portNode) flagRefs(flags reuseFlag) int {
+ var total int
+ for i, r := range p.refs {
+ if reuseFlag(i)&flags == flags {
+ total += r
+ }
+ }
+ return total
+}
+
+// allRefsHave reports whether all references have all the specified flags.
+func (p portNode) allRefsHave(flags reuseFlag) bool {
+ for i, r := range p.refs {
+ if reuseFlag(i)&flags != flags && r > 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// intersectionRefs returns the set of flags shared by all references.
+func (p portNode) intersectionRefs() reuseFlag {
+ intersection := flagMask
+ for i, r := range p.refs {
+ if r > 0 {
+ intersection &= reuseFlag(i)
+ }
+ }
+ return intersection
}
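
A worked example of this bookkeeping, using the indices established by bits() (0 = none, 1 = MostRecent, 2 = LoadBalanced, 3 = both):

    // Sketch: two holders of one port, with overlapping flags.
    var n portNode
    n.refs[loadBalancedFlag]++                // SO_REUSEPORT only
    n.refs[mostRecentFlag|loadBalancedFlag]++ // both flags set
    _ = n.totalRefs()                // 2
    _ = n.flagRefs(loadBalancedFlag) // 2: both hold LoadBalanced
    _ = n.intersectionRefs()         // loadBalancedFlag, the one flag all share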
// deviceNode is never empty. When it has no elements, it is removed from the
@@ -66,30 +140,44 @@ type deviceNode map[tcpip.NICID]portNode
// isAvailable checks whether binding is possible by device. If not binding to a
// device, check against all portNodes. If binding to a specific device, check
// against the unspecified device and the provided device.
-func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool {
+//
+// If either of the port reuse flags is enabled on any of the nodes, all nodes
+// sharing a port must share at least one reuse flag. This matches Linux's
+// behavior.
+func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool {
+ flagBits := flags.bits()
if bindToDevice == 0 {
// Trying to bind to all devices.
- if !reuse {
+ if flagBits == 0 {
// Can't bind because the (addr,port) is already bound.
return false
}
+ intersection := flagMask
for _, p := range d {
- if !p.reuse {
- // Can't bind because the (addr,port) was previously bound without reuse.
+ i := p.intersectionRefs()
+ intersection &= i
+ if intersection&flagBits == 0 {
+ // Can't bind because the (addr,port) was
+ // previously bound without reuse.
return false
}
}
return true
}
+ intersection := flagMask
+
if p, ok := d[0]; ok {
- if !reuse || !p.reuse {
+ intersection = p.intersectionRefs()
+ if intersection&flagBits == 0 {
return false
}
}
if p, ok := d[bindToDevice]; ok {
- if !reuse || !p.reuse {
+ i := p.intersectionRefs()
+ intersection &= i
+ if intersection&flagBits == 0 {
return false
}
}
@@ -103,12 +191,12 @@ type bindAddresses map[tcpip.Address]deviceNode
// isAvailable checks whether an IP address is available to bind to. If the
// address is the "any" address, check all other addresses. Otherwise, just
// check against the "any" address and the provided address.
-func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool {
+func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID) bool {
if addr == anyIPAddress {
// If binding to the "any" address then check that there are no conflicts
// with all addresses.
for _, d := range b {
- if !d.isAvailable(reuse, bindToDevice) {
+ if !d.isAvailable(flags, bindToDevice) {
return false
}
}
@@ -117,14 +205,14 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice
// Check that there is no conflict with the "any" address.
if d, ok := b[anyIPAddress]; ok {
- if !d.isAvailable(reuse, bindToDevice) {
+ if !d.isAvailable(flags, bindToDevice) {
return false
}
}
// Check that there is no conflict with the provided address.
if d, ok := b[addr]; ok {
- if !d.isAvailable(reuse, bindToDevice) {
+ if !d.isAvailable(flags, bindToDevice) {
return false
}
}
@@ -190,17 +278,17 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui
}
// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool {
s.mu.Lock()
defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice)
+ return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice)
}
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr, reuse, bindToDevice) {
+ if !addrs.isAvailable(addr, flags, bindToDevice) {
return false
}
}
@@ -212,14 +300,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) {
+ if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice) {
return 0, tcpip.ErrPortInUse
}
return port, nil
@@ -227,15 +315,16 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil
+ return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice), nil
})
}
// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) {
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) {
return false
}
+ flagBits := flags.bits()
// Reserve port on all network protocols.
for _, network := range networks {
@@ -250,12 +339,9 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
d = make(deviceNode)
m[addr] = d
}
- if n, ok := d[bindToDevice]; ok {
- n.refs++
- d[bindToDevice] = n
- } else {
- d[bindToDevice] = portNode{reuse: reuse, refs: 1}
- }
+ n := d[bindToDevice]
+ n.refs[flagBits]++
+ d[bindToDevice] = n
}
return true
@@ -263,10 +349,12 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) {
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) {
s.mu.Lock()
defer s.mu.Unlock()
+ flagBits := flags.bits()
+
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if m, ok := s.allocatedPorts[desc]; ok {
@@ -278,9 +366,9 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp
if !ok {
continue
}
- n.refs--
+ n.refs[flagBits]--
d[bindToDevice] = n
- if n.refs == 0 {
+ if n.refs == [nextFlag]int{} {
delete(d, bindToDevice)
}
if len(d) == 0 {
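
A usage sketch of the Flags-based API that replaces the single reuse bool. The names fakeNetworkNumber, fakeTransNumber, and addr are hypothetical stand-ins (the tests below use similar fakes); the governing rule is that a new binder must share at least one reuse flag with every existing holder of the (addr, port) pair:

    // Sketch: reserve, share via SO_REUSEPORT, then release.
    pm := NewPortManager()
    nets := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
    lb := Flags{LoadBalanced: true}

    port, _ := pm.ReservePort(nets, fakeTransNumber, addr, 0, lb, 0 /* bindToDevice */)
    _, err := pm.ReservePort(nets, fakeTransNumber, addr, port, lb, 0)      // ok: shares LoadBalanced
    _, err = pm.ReservePort(nets, fakeTransNumber, addr, port, Flags{}, 0)  // ErrPortInUse: no shared flag
    _ = err

    // ReleasePort must receive the flags used at reservation time so the
    // correct refs bucket is decremented.
    pm.ReleasePort(nets, fakeTransNumber, addr, port, lb, 0)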
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 19f4833fc..d6969d050 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -33,7 +33,7 @@ type portReserveTestAction struct {
port uint16
ip tcpip.Address
want *tcpip.Error
- reuse bool
+ flags Flags
release bool
device tcpip.NICID
}
@@ -50,7 +50,7 @@ func TestPortReservation(t *testing.T) {
{port: 80, ip: fakeIPAddress1, want: nil},
/* N.B. Order of tests matters! */
{port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse},
- {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true},
+ {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
},
},
{
@@ -61,7 +61,7 @@ func TestPortReservation(t *testing.T) {
/* release fakeIPAddress, but anyIPAddress is still in use */
{port: 22, ip: fakeIPAddress, release: true},
{port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true},
+ {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
/* Release port 22 from any IP address, then try to reserve fake IP address on 22 */
{port: 22, ip: anyIPAddress, want: nil, release: true},
{port: 22, ip: fakeIPAddress, want: nil},
@@ -71,36 +71,36 @@ func TestPortReservation(t *testing.T) {
actions: []portReserveTestAction{
{port: 00, ip: fakeIPAddress, want: nil},
{port: 00, ip: fakeIPAddress, want: nil},
- {port: 00, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
},
}, {
tname: "bind to ip with reuseport",
actions: []portReserveTestAction{
- {port: 25, ip: fakeIPAddress, reuse: true, want: nil},
- {port: 25, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 25, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse},
- {port: 25, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 25, ip: anyIPAddress, reuse: true, want: nil},
+ {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
},
}, {
tname: "bind to inaddr any with reuseport",
actions: []portReserveTestAction{
- {port: 24, ip: anyIPAddress, reuse: true, want: nil},
- {port: 24, ip: anyIPAddress, reuse: true, want: nil},
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, release: true, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil},
- {port: 24, ip: anyIPAddress, release: true},
- {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 24, ip: anyIPAddress, release: true},
- {port: 24, ip: anyIPAddress, reuse: false, want: nil},
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil},
},
}, {
tname: "bind twice with device fails",
@@ -125,88 +125,152 @@ func TestPortReservation(t *testing.T) {
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, want: nil},
{port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
{port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
},
}, {
tname: "bind with device",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, want: nil},
{port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
{port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 789, want: nil},
{port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
},
}, {
- tname: "bind with reuse",
+ tname: "bind with reuseport",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
},
}, {
- tname: "binding with reuse and device",
+ tname: "binding with reuseport and device",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse},
},
}, {
- tname: "mixing reuse and not reuse by binding to device",
+ tname: "mixing reuseport and not reuseport by binding to device",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 456, want: nil},
- {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 999, want: nil},
},
}, {
- tname: "can't bind to 0 after mixing reuse and not reuse",
+ tname: "can't bind to 0 after mixing reuseport and not reuseport",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 456, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
},
}, {
tname: "bind and release",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
// Release the bind to device 0 and try again.
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true},
- {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil},
},
}, {
- tname: "bind twice with reuse once",
+ tname: "bind twice with reuseport once",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
},
}, {
tname: "release an unreserved device",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
- {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil},
// The below don't exist.
- {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true},
- {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true},
+ {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true},
// Release all.
- {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
- {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true},
+ },
+ }, {
+ tname: "bind with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil},
+ },
+ }, {
+ tname: "bind twice with reuseaddr once",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport, and then reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport, and then reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport twice, and then reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport twice, and then reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ },
+ }, {
+ tname: "bind with reuseaddr, and then reuseaddr and reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseport, and then reuseaddr and reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
},
},
} {
@@ -216,12 +280,12 @@ func TestPortReservation(t *testing.T) {
for _, test := range test.actions {
if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device)
+ pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device)
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device)
if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want)
+ t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want)
}
if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
index a57752a7c..d7496fde6 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD
@@ -5,6 +5,7 @@ package(licenses = ["notice"])
go_binary(
name = "tun_tcp_connect",
srcs = ["main.go"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD
index dad8ef399..875561566 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD
@@ -5,6 +5,7 @@ package(licenses = ["notice"])
go_binary(
name = "tun_tcp_echo",
srcs = ["main.go"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/link/fdbased",
diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD
index 29b7d761c..b31ddba2f 100644
--- a/pkg/tcpip/seqnum/BUILD
+++ b/pkg/tcpip/seqnum/BUILD
@@ -6,7 +6,5 @@ go_library(
name = "seqnum",
srcs = ["seqnum.go"],
importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
)
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index bfc03e90b..69077669a 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -31,9 +31,7 @@ go_library(
"transport_demuxer.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/stack",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/ilist",
"//pkg/rand",
@@ -73,6 +71,7 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
+ "@com_github_google_go-cmp//cmp:go_default_library",
],
)
@@ -86,11 +85,3 @@ go_test(
"//pkg/tcpip",
],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "linkaddrentry_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index bed60d7b1..90664ba8a 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -38,6 +38,34 @@ const (
// Default = 1s (from RFC 4861 section 10).
defaultRetransmitTimer = time.Second
+ // defaultHandleRAs is the default configuration for whether or not to
+ // handle incoming Router Advertisements as a host.
+ //
+ // Default = true.
+ defaultHandleRAs = true
+
+ // defaultDiscoverDefaultRouters is the default configuration for
+ // whether or not to discover default routers from incoming Router
+ // Advertisements, as a host.
+ //
+ // Default = true.
+ defaultDiscoverDefaultRouters = true
+
+ // defaultDiscoverOnLinkPrefixes is the default configuration for
+ // whether or not to discover on-link prefixes from incoming Router
+ // Advertisements' Prefix Information option, as a host.
+ //
+ // Default = true.
+ defaultDiscoverOnLinkPrefixes = true
+
+ // defaultAutoGenGlobalAddresses is the default configuration for
+ // whether or not to generate global IPv6 addresses in response to
+ // receiving a new Prefix Information option with its Autonomous
+ // Address-Configuration flag set, as a host.
+ //
+ // Default = true.
+ defaultAutoGenGlobalAddresses = true
+
// minimumRetransmitTimer is the minimum amount of time to wait between
// sending NDP Neighbor solicitation messages. Note, RFC 4861 does
// not impose a minimum Retransmit Timer, but we do here to make sure
@@ -49,8 +77,117 @@ const (
//
// Min = 1ms.
minimumRetransmitTimer = time.Millisecond
+
+ // MaxDiscoveredDefaultRouters is the maximum number of discovered
+ // default routers. The stack should stop discovering new routers after
+ // discovering MaxDiscoveredDefaultRouters routers.
+ //
+ // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
+ // SHOULD be more.
+ //
+ // Max = 10.
+ MaxDiscoveredDefaultRouters = 10
+
+ // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered
+ // on-link prefixes. The stack should stop discovering new on-link
+ // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link
+ // prefixes.
+ //
+ // Max = 10.
+ MaxDiscoveredOnLinkPrefixes = 10
+
+ // validPrefixLenForAutoGen is the expected prefix length that an
+ // address can be generated for. Must be 64 bits as the interface
+ // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so
+ // 128 - 64 = 64.
+ validPrefixLenForAutoGen = 64
)
+var (
+ // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid
+ // Lifetime to update the valid lifetime of a generated address by
+ // SLAAC.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // Min = 2hrs.
+ MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour
+)
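This 2-hour floor feeds the valid-lifetime update rule of RFC 4862 section 5.5.3.e that handleAutonomousPrefixInformation applies further down. A minimal sketch of that decision (illustrative names, assumes the time import; not the package's API):

// effectiveValidLifetime returns the lifetime to apply to an existing
// SLAAC address, given the advertised Valid Lifetime and the time
// remaining on the address; min is
// MinPrefixInformationValidLifetimeForUpdate (2 hours by default).
func effectiveValidLifetime(advertised, remaining, min time.Duration) (time.Duration, bool) {
	switch {
	case advertised > min || advertised > remaining:
		// 1) A large (or extending) advertised lifetime is honored.
		return advertised, true
	case remaining <= min:
		// 2) With at most 2 hours left, small advertised lifetimes
		//    are ignored so a spoofed RA cannot expire the address
		//    early.
		return 0, false
	default:
		// 3) Otherwise the lifetime is clamped down to the floor.
		return min, true
	}
}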
+
+// NDPDispatcher is the interface integrators of netstack must implement to
+// receive and handle NDP related events.
+type NDPDispatcher interface {
+ // OnDuplicateAddressDetectionStatus will be called when the DAD process
+ // for an address (addr) on a NIC (with ID nicID) completes. resolved
+ // will be set to true if DAD completed successfully (no duplicate addr
+ // detected); false otherwise (addr was detected to be a duplicate on
+ // the link the NIC is a part of, or it was stopped for some other
+ // reason, such as the address being removed). If an error occurred
+ // during DAD, err will be set and resolved must be ignored.
+ //
+ // This function is permitted to block indefinitely without interfering
+ // with the stack's operation.
+ OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
+
+ // OnDefaultRouterDiscovered will be called when a new default router is
+ // discovered. Implementations must return true if the newly discovered
+ // router should be remembered.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool
+
+ // OnDefaultRouterInvalidated will be called when a discovered default
+ // router that was remembered is invalidated.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address)
+
+ // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is
+ // discovered. Implementations must return true if the newly discovered
+ // on-link prefix should be remembered.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool
+
+ // OnOnLinkPrefixInvalidated will be called when a discovered on-link
+ // prefix that was remembered is invalidated.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet)
+
+ // OnAutoGenAddress will be called when a new prefix with its
+ // autonomous address-configuration flag set has been received and SLAAC
+ // has been performed. Implementations may prevent the stack from
+ // assigning the address to the NIC by returning false.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
+
+ // OnAutoGenAddressInvalidated will be called when an auto-generated
+ // address (as part of SLAAC) has been invalidated.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix)
+
+ // OnRecursiveDNSServerOption will be called when an NDP option with
+ // recursive DNS servers has been received. Note, addrs may contain
+ // link-local addresses.
+ //
+ // It is up to the caller to use the DNS Servers only for their valid
+ // lifetime. OnRecursiveDNSServerOption may be called for new or
+ // already known DNS servers. If called with known DNS servers, their
+ // valid lifetimes must be refreshed to lifetime (it may be increased,
+ // decreased, or completely invalidated when lifetime = 0).
+ OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
+}
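For reference, a minimal dispatcher that satisfies this interface; the channel-based ndpDispatcher in ndp_test.go further down is the test equivalent. The boolean results are the integrator's remember/forget decision, and every callback except the DAD one must return quickly and must not call back into the stack:

type noopNDPDispatcher struct{}

func (*noopNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
}

// Returning true asks the stack to remember the discovered router;
// false drops it without tracking. The prefix and address hooks work
// the same way.
func (*noopNDPDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { return true }
func (*noopNDPDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address)     {}
func (*noopNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool   { return true }
func (*noopNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet)       {}
func (*noopNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { return true }
func (*noopNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {}
func (*noopNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {}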
+
// NDPConfigurations is the NDP configurations for the netstack.
type NDPConfigurations struct {
// The number of Neighbor Solicitation messages to send when doing
@@ -64,6 +201,31 @@ type NDPConfigurations struct {
//
// Must be greater than 0.5s.
RetransmitTimer time.Duration
+
+ // HandleRAs determines whether or not Router Advertisements will be
+ // processed.
+ HandleRAs bool
+
+ // DiscoverDefaultRouters determines whether or not default routers will
+ // be discovered from Router Advertisements. This configuration is
+ // ignored if HandleRAs is false.
+ DiscoverDefaultRouters bool
+
+ // DiscoverOnLinkPrefixes determines whether or not on-link prefixes
+ // will be discovered from Router Advertisements' Prefix Information
+ // option. This configuration is ignored if HandleRAs is false.
+ DiscoverOnLinkPrefixes bool
+
+ // AutoGenGlobalAddresses determines whether or not global IPv6
+ // addresses will be generated for a NIC in response to receiving a new
+ // Prefix Information option with its Autonomous
+ // Address-Configuration flag set, as a host, as per RFC 4862 (SLAAC).
+ //
+ // Note, if an address was already generated for some unique prefix, as
+ // part of SLAAC, this option does not affect whether or not the
+ // lifetime(s) of the generated address changes; this option only
+ // affects the generation of new addresses as part of SLAAC.
+ AutoGenGlobalAddresses bool
}
// DefaultNDPConfigurations returns an NDPConfigurations populated with
@@ -72,6 +234,10 @@ func DefaultNDPConfigurations() NDPConfigurations {
return NDPConfigurations{
DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
RetransmitTimer: defaultRetransmitTimer,
+ HandleRAs: defaultHandleRAs,
+ DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
+ DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
+ AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
}
}
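A sketch of the integrator-side wiring these defaults enable, mirroring the test setup further down (NDPConfigs and NDPDisp are the stack.Options fields the tests use; the helper name is illustrative):

func newHostStack(disp stack.NDPDispatcher) *stack.Stack {
	configs := stack.DefaultNDPConfigurations()
	configs.DiscoverOnLinkPrefixes = false // Example: opt out of one feature.
	return stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
		NDPConfigs:       configs,
		NDPDisp:          disp,
	})
}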
@@ -88,8 +254,24 @@ func (c *NDPConfigurations) validate() {
// ndpState is the per-interface NDP state.
type ndpState struct {
+ // The NIC this ndpState is for.
+ nic *NIC
+
+ // configs is the per-interface NDP configurations.
+ configs NDPConfigurations
+
// The DAD state to send the next NS message, or resolve the address.
dad map[tcpip.Address]dadState
+
+ // The default routers discovered through Router Advertisements.
+ defaultRouters map[tcpip.Address]defaultRouterState
+
+ // The on-link prefixes discovered through Router Advertisements' Prefix
+ // Information option.
+ onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState
+
+ // The addresses generated by SLAAC.
+ autoGenAddresses map[tcpip.Address]autoGenAddressState
}
// dadState holds the Duplicate Address Detection timer and channel to signal
@@ -105,13 +287,84 @@ type dadState struct {
done *bool
}
+// defaultRouterState holds data associated with a default router discovered by
+// a Router Advertisement (RA).
+type defaultRouterState struct {
+ invalidationTimer *time.Timer
+
+ // Used to inform the timer not to invalidate the default router (R) in
+ // a race condition (T1 is a goroutine that handles an RA from R and T2
+ // is the goroutine that handles R's invalidation timer firing):
+ // T1: Receive a new RA from R
+ // T1: Obtain the NIC's lock before processing the RA
+ // T2: R's invalidation timer fires, and gets blocked on obtaining the
+ // NIC's lock
+ // T1: Refreshes/extends R's lifetime & releases NIC's lock
+ // T2: Obtains NIC's lock & invalidates R immediately
+ //
+ // To resolve this, T1 will check to see if the timer already fired, and
+ // inform the timer using doNotInvalidate to not invalidate R, so that
+ // once T2 obtains the lock, it will see that it is set to true and do
+ // nothing further.
+ doNotInvalidate *bool
+}
+
+// onLinkPrefixState holds data associated with an on-link prefix discovered by
+// a Router Advertisement's Prefix Information option (PI) when the NDP
+// configurations was configured to do so.
+type onLinkPrefixState struct {
+ invalidationTimer *time.Timer
+
+ // Used to signal the timer not to invalidate the on-link prefix (P) in
+ // a race condition (T1 is a goroutine that handles a PI for P and T2
+ // is the goroutine that handles P's invalidation timer firing):
+ // T1: Receive a new PI for P
+ // T1: Obtain the NIC's lock before processing the PI
+ // T2: P's invalidation timer fires, and gets blocked on obtaining the
+ // NIC's lock
+ // T1: Refreshes/extends P's lifetime & releases NIC's lock
+ // T2: Obtains NIC's lock & invalidates P immediately
+ //
+ // To resolve this, T1 will check to see if the timer already fired, and
+ // inform the timer using doNotInvalidate to not invalidate P, so that
+ // once T2 obtains the lock, it will see that it is set to true and do
+ // nothing further.
+ doNotInvalidate *bool
+}
+
+// autoGenAddressState holds data associated with an address generated via
+// SLAAC.
+type autoGenAddressState struct {
+ invalidationTimer *time.Timer
+
+ // Used to signal the timer not to invalidate the SLAAC address (A) in
+ // a race condition (T1 is a goroutine that handles a PI for A and T2
+ // is the goroutine that handles A's invalidation timer firing):
+ // T1: Receive a new PI for A
+ // T1: Obtain the NIC's lock before processing the PI
+ // T2: A's invalidation timer fires, and gets blocked on obtaining the
+ // NIC's lock
+ // T1: Refreshes/extends A's lifetime & releases NIC's lock
+ // T2: Obtains NIC's lock & invalidates A immediately
+ //
+ // To resolve this, T1 will check to see if the timer already fired, and
+ // inform the timer using doNotInvalidate to not invalidate A, so that
+ // once T2 obtains the lock, it will see that it is set to true and do
+ // nothing further.
+ doNotInvalidate *bool
+
+ // Nonzero only when the address is not valid forever (invalidationTimer
+ // is not nil).
+ validUntil time.Time
+}
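All three states above share the same timer-versus-lock dance, which is easier to see in isolation. A self-contained sketch (assumes the sync and time imports; names are illustrative): if Stop reports false, the timer has already fired and its callback is blocked on the lock, so the refreshing goroutine flags it to back off.

type invalidatingEntry struct {
	mu              sync.Mutex
	timer           *time.Timer
	doNotInvalidate bool
}

func (e *invalidatingEntry) start(lifetime time.Duration, invalidate func()) {
	e.timer = time.AfterFunc(lifetime, func() {
		e.mu.Lock()
		defer e.mu.Unlock()
		if e.doNotInvalidate {
			// A refresh raced with this firing; keep the entry.
			e.doNotInvalidate = false
			return
		}
		invalidate()
	})
}

// refresh extends the lifetime. The caller holds e.mu, matching the
// NIC-lock discipline in the code below.
func (e *invalidatingEntry) refresh(lifetime time.Duration) {
	if !e.timer.Stop() {
		// The timer already fired; its callback will consume this
		// flag and do nothing once it acquires e.mu.
		e.doNotInvalidate = true
	}
	e.timer.Reset(lifetime)
}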
+
// startDuplicateAddressDetection performs Duplicate Address Detection.
//
// This function must only be called by IPv6 addresses that are currently
// tentative.
//
-// The NIC that ndp belongs to (n) MUST be locked.
-func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
// addr must be a valid unicast IPv6 address.
if !header.IsV6UnicastAddress(addr) {
return tcpip.ErrAddressFamilyNotSupported
@@ -127,13 +380,13 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address,
// reference count would have been increased without doing the
// work that would have been done for an address that was brand
// new. See NIC.addPermanentAddressLocked.
- panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, n.ID()))
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
}
- remaining := n.stack.ndpConfigs.DupAddrDetectTransmits
+ remaining := ndp.configs.DupAddrDetectTransmits
{
- done, err := ndp.doDuplicateAddressDetection(n, addr, remaining, ref)
+ done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref)
if err != nil {
return err
}
@@ -146,42 +399,59 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address,
var done bool
var timer *time.Timer
- timer = time.AfterFunc(n.stack.ndpConfigs.RetransmitTimer, func() {
- n.mu.Lock()
- defer n.mu.Unlock()
+ timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() {
+ var d bool
+ var err *tcpip.Error
- if done {
- // If we reach this point, it means that the DAD timer
- // fired after another goroutine already obtained the
- // NIC lock and stopped DAD before it this function
- // obtained the NIC lock. Simply return here and do
- // nothing further.
- return
- }
+ // doDadIteration does a single iteration of the DAD loop.
+ //
+ // Returns true if the integrator needs to be informed of DAD
+ // completing.
+ doDadIteration := func() bool {
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
- ref, ok := n.endpoints[NetworkEndpointID{addr}]
- if !ok {
- // This should never happen.
- // We should have an endpoint for addr since we are
- // still performing DAD on it. If the endpoint does not
- // exist, but we are doing DAD on it, then we started
- // DAD at some point, but forgot to stop it when the
- // endpoint was deleted.
- panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, n.ID()))
- }
+ if done {
+ // If we reach this point, it means that the DAD
+ // timer fired after another goroutine already
+ // obtained the NIC lock and stopped DAD before
+ // this function obtained the NIC lock. Simply
+ // return here and do nothing further.
+ return false
+ }
- if done, err := ndp.doDuplicateAddressDetection(n, addr, remaining, ref); err != nil || done {
- if err != nil {
- log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, n.ID(), err)
+ ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ // This should never happen.
+ // We should have an endpoint for addr since we
+ // are still performing DAD on it. If the
+ // endpoint does not exist, but we are doing DAD
+ // on it, then we started DAD at some point, but
+ // forgot to stop it when the endpoint was
+ // deleted.
+ panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID()))
}
- ndp.stopDuplicateAddressDetection(addr)
- return
- }
+ d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref)
+ if err != nil || d {
+ delete(ndp.dad, addr)
- timer.Reset(n.stack.ndpConfigs.RetransmitTimer)
- remaining--
+ if err != nil {
+ log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err)
+ }
+ // Let the integrator know DAD has completed.
+ return true
+ }
+
+ remaining--
+ timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer)
+ return false
+ }
+
+ if doDadIteration() && ndp.nic.stack.ndpDisp != nil {
+ ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err)
+ }
})
ndp.dad[addr] = dadState{
@@ -204,11 +474,11 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address,
// The NIC that ndp belongs to (n) MUST be locked.
//
// Returns true if DAD has resolved; false if DAD is still ongoing.
-func (ndp *ndpState) doDuplicateAddressDetection(n *NIC, addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) {
+func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) {
if ref.getKind() != permanentTentative {
// The endpoint should still be marked as tentative
// since we are still performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, n.ID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
}
if remaining == 0 {
@@ -219,17 +489,17 @@ func (ndp *ndpState) doDuplicateAddressDetection(n *NIC, addr tcpip.Address, rem
// Send a new NS.
snmc := header.SolicitedNodeAddr(addr)
- snmcRef, ok := n.endpoints[NetworkEndpointID{snmc}]
+ snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}]
if !ok {
// This should never happen as if we have the
// address, we should have the solicited-node
// address.
- panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", n.ID(), snmc, addr))
+ panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr))
}
// Use the unspecified address as the source address when performing
// DAD.
- r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, n.linkEP.LinkAddress(), snmcRef, false, false)
+ r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false)
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
@@ -239,7 +509,9 @@ func (ndp *ndpState) doDuplicateAddressDetection(n *NIC, addr tcpip.Address, rem
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil, hdr, buffer.VectorisedView{}, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: DefaultTOS}); err != nil {
+ if err := r.WritePacket(nil, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ }); err != nil {
sent.Dropped.Increment()
return false, err
}
@@ -275,5 +547,611 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
delete(ndp.dad, addr)
- return
+ // Let the integrator know DAD did not resolve.
+ if ndp.nic.stack.ndpDisp != nil {
+ go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ }
+}
+
+// handleRA handles a Router Advertisement message that arrived on the NIC
+// this ndp is for. Does nothing if the NIC is configured to not handle RAs.
+//
+// The NIC that ndp belongs to and its associated stack MUST be locked.
+func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
+ // Is the NIC configured to handle RAs at all?
+ //
+ // Currently, the stack does not determine router interface status on a
+ // per-interface basis; it is a stack-wide configuration, so we check
+ // the stack's forwarding flag to determine if the NIC is a routing
+ // interface.
+ if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding {
+ return
+ }
+
+ // Is the NIC configured to discover default routers?
+ if ndp.configs.DiscoverDefaultRouters {
+ rtr, ok := ndp.defaultRouters[ip]
+ rl := ra.RouterLifetime()
+ switch {
+ case !ok && rl != 0:
+ // This is a new default router we are discovering.
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredDefaultRouters routers.
+ if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters {
+ ndp.rememberDefaultRouter(ip, rl)
+ }
+
+ case ok && rl != 0:
+ // This is an already discovered default router. Update
+ // the invalidation timer.
+ timer := rtr.invalidationTimer
+
+ // We should ALWAYS have an invalidation timer for a
+ // discovered router.
+ if timer == nil {
+ panic("ndphandlera: RA invalidation timer should not be nil")
+ }
+
+ if !timer.Stop() {
+ // If we reach this point, then we know the
+ // timer fired after we already took the NIC
+ // lock. Inform the timer not to invalidate the
+ // router when it obtains the lock as we just
+ // got a new RA that refreshes its lifetime to a
+ // non-zero value. See
+ // defaultRouterState.doNotInvalidate for more
+ // details.
+ *rtr.doNotInvalidate = true
+ }
+
+ timer.Reset(rl)
+
+ case ok && rl == 0:
+ // We know about the router but it is no longer to be
+ // used as a default router so invalidate it.
+ ndp.invalidateDefaultRouter(ip)
+ }
+ }
+
+ // TODO(b/141556115): Do (RetransTimer, ReachableTime) Parameter
+ // Discovery.
+
+ // We know the options are valid as far as wire format is concerned
+ // since we got the Router Advertisement, as documented by this
+ // function, so we do not check the iterator for errors on calls to
+ // Next.
+ it, _ := ra.Options().Iter(false)
+ for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() {
+ switch opt := opt.(type) {
+ case header.NDPRecursiveDNSServer:
+ if ndp.nic.stack.ndpDisp == nil {
+ continue
+ }
+
+ ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), opt.Addresses(), opt.Lifetime())
+
+ case header.NDPPrefixInformation:
+ prefix := opt.Subnet()
+
+ // Is the prefix a link-local?
+ if header.IsV6LinkLocalAddress(prefix.ID()) {
+ // ...Yes, skip as per RFC 4861 section 6.3.4,
+ // and RFC 4862 section 5.5.3.b (for SLAAC).
+ continue
+ }
+
+ // Is the Prefix Length 0?
+ if prefix.Prefix() == 0 {
+ // ...Yes, skip as this is an invalid prefix
+ // as all IPv6 addresses cannot be on-link.
+ continue
+ }
+
+ if opt.OnLinkFlag() {
+ ndp.handleOnLinkPrefixInformation(opt)
+ }
+
+ if opt.AutonomousAddressConfigurationFlag() {
+ ndp.handleAutonomousPrefixInformation(opt)
+ }
+ }
+
+ // TODO(b/141556115): Do (MTU) Parameter Discovery.
+ }
+}
+
+// invalidateDefaultRouter invalidates a discovered default router.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
+ rtr, ok := ndp.defaultRouters[ip]
+
+ // Is the router still discovered?
+ if !ok {
+ // ...Nope, do nothing further.
+ return
+ }
+
+ rtr.invalidationTimer.Stop()
+ rtr.invalidationTimer = nil
+ *rtr.doNotInvalidate = true
+ rtr.doNotInvalidate = nil
+
+ delete(ndp.defaultRouters, ip)
+
+ // Let the integrator know a discovered default router is invalidated.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip)
+ }
+}
+
+// rememberDefaultRouter remembers a newly discovered default router with IPv6
+// link-local address ip with lifetime rl.
+//
+// The router identified by ip MUST NOT already be known by the NIC.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
+ return
+ }
+
+ // Inform the integrator when we discovered a default router.
+ if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) {
+ // Informed by the integrator to not remember the router, do
+ // nothing further.
+ return
+ }
+
+ // Used to signal the timer not to invalidate the default router (R) in
+ // a race condition. See defaultRouterState.doNotInvalidate for more
+ // details.
+ var doNotInvalidate bool
+
+ ndp.defaultRouters[ip] = defaultRouterState{
+ invalidationTimer: time.AfterFunc(rl, func() {
+ ndp.nic.stack.mu.Lock()
+ defer ndp.nic.stack.mu.Unlock()
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
+ if doNotInvalidate {
+ doNotInvalidate = false
+ return
+ }
+
+ ndp.invalidateDefaultRouter(ip)
+ }),
+ doNotInvalidate: &doNotInvalidate,
+ }
+}
+
+// rememberOnLinkPrefix remembers a newly discovered on-link prefix,
+// prefix, with lifetime l.
+//
+// The prefix identified by prefix MUST NOT already be known.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) {
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
+ return
+ }
+
+ // Inform the integrator when we discovered an on-link prefix.
+ if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) {
+ // Informed by the integrator to not remember the prefix, do
+ // nothing further.
+ return
+ }
+
+ // Used to signal the timer not to invalidate the on-link prefix (P) in
+ // a race condition. See onLinkPrefixState.doNotInvalidate for more
+ // details.
+ var doNotInvalidate bool
+ var timer *time.Timer
+
+ // Only create a timer if the lifetime is not infinite.
+ if l < header.NDPInfiniteLifetime {
+ timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate)
+ }
+
+ ndp.onLinkPrefixes[prefix] = onLinkPrefixState{
+ invalidationTimer: timer,
+ doNotInvalidate: &doNotInvalidate,
+ }
+}
+
+// invalidateOnLinkPrefix invalidates a discovered on-link prefix.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
+ s, ok := ndp.onLinkPrefixes[prefix]
+
+ // Is the on-link prefix still discovered?
+ if !ok {
+ // ...Nope, do nothing further.
+ return
+ }
+
+ if s.invalidationTimer != nil {
+ s.invalidationTimer.Stop()
+ s.invalidationTimer = nil
+ *s.doNotInvalidate = true
+ }
+
+ s.doNotInvalidate = nil
+
+ delete(ndp.onLinkPrefixes, prefix)
+
+ // Let the integrator know a discovered on-link prefix is invalidated.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix)
+ }
+}
+
+// prefixInvalidationCallback returns a new on-link prefix invalidation timer
+// for prefix that fires after vl.
+//
+// doNotInvalidate is used to signal the timer when it fires at the same time
+// that a prefix's valid lifetime gets refreshed. See
+// onLinkPrefixState.doNotInvalidate for more details.
+func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Duration, doNotInvalidate *bool) *time.Timer {
+ return time.AfterFunc(vl, func() {
+ ndp.nic.stack.mu.Lock()
+ defer ndp.nic.stack.mu.Unlock()
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
+ if *doNotInvalidate {
+ *doNotInvalidate = false
+ return
+ }
+
+ ndp.invalidateOnLinkPrefix(prefix)
+ })
+}
+
+// handleOnLinkPrefixInformation handles a Prefix Information option with
+// its on-link flag set, as per RFC 4861 section 6.3.4.
+//
+// handleOnLinkPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the on-link flag is set.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) {
+ prefix := pi.Subnet()
+ prefixState, ok := ndp.onLinkPrefixes[prefix]
+ vl := pi.ValidLifetime()
+
+ if !ok && vl == 0 {
+ // Don't know about this prefix but it has a zero valid
+ // lifetime, so just ignore.
+ return
+ }
+
+ if !ok && vl != 0 {
+ // This is a new on-link prefix we are discovering
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredOnLinkPrefixes on-link prefixes.
+ if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes {
+ ndp.rememberOnLinkPrefix(prefix, vl)
+ }
+ return
+ }
+
+ if ok && vl == 0 {
+ // We know about the on-link prefix, but it is
+ // no longer to be considered on-link, so
+ // invalidate it.
+ ndp.invalidateOnLinkPrefix(prefix)
+ return
+ }
+
+ // This is an already discovered on-link prefix with a
+ // new non-zero valid lifetime.
+ // Update the invalidation timer.
+ timer := prefixState.invalidationTimer
+
+ if timer == nil && vl >= header.NDPInfiniteLifetime {
+ // Had infinite valid lifetime before and
+ // continues to have an infinite valid
+ // lifetime. Do nothing further.
+ return
+ }
+
+ if timer != nil && !timer.Stop() {
+ // If we reach this point, then we know the timer already fired
+ // after we took the NIC lock. Inform the timer to not
+ // invalidate the prefix once it obtains the lock as we just
+ // got a new PI that refreshes its lifetime to a non-zero value.
+ // See onLinkPrefixState.doNotInvalidate for more details.
+ *prefixState.doNotInvalidate = true
+ }
+
+ if vl >= header.NDPInfiniteLifetime {
+ // Prefix is now valid forever so we don't need
+ // an invalidation timer.
+ prefixState.invalidationTimer = nil
+ ndp.onLinkPrefixes[prefix] = prefixState
+ return
+ }
+
+ if timer != nil {
+ // We already have a timer so just reset it to
+ // expire after the new valid lifetime.
+ timer.Reset(vl)
+ return
+ }
+
+ // We do not have a timer so just create a new one.
+ prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate)
+ ndp.onLinkPrefixes[prefix] = prefixState
+}
+
+// handleAutonomousPrefixInformation handles a Prefix Information option with
+// its autonomous flag set, as per RFC 4862 section 5.5.3.
+//
+// handleAutonomousPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the autonomous flag is set.
+//
+// The NIC that ndp belongs to and its associated stack MUST be locked.
+func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) {
+ vl := pi.ValidLifetime()
+ pl := pi.PreferredLifetime()
+
+ // If the preferred lifetime is greater than the valid lifetime,
+ // silently ignore the Prefix Information option, as per RFC 4862
+ // section 5.5.3.c.
+ if pl > vl {
+ return
+ }
+
+ prefix := pi.Subnet()
+
+ // Check if we already have an auto-generated address for prefix.
+ for _, ref := range ndp.nic.endpoints {
+ if ref.protocol != header.IPv6ProtocolNumber {
+ continue
+ }
+
+ if ref.configType != slaac {
+ continue
+ }
+
+ addr := ref.ep.ID().LocalAddress
+ refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: ref.ep.PrefixLen()}
+ if refAddrWithPrefix.Subnet() != prefix {
+ continue
+ }
+
+ //
+ // At this point, we know we are refreshing a SLAAC generated
+ // IPv6 address with the prefix, prefix. Do the work as outlined
+ // by RFC 4862 section 5.5.3.e.
+ //
+
+ addrState, ok := ndp.autoGenAddresses[addr]
+ if !ok {
+ panic(fmt.Sprintf("must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr))
+ }
+
+ // TODO(b/143713887): Handle deprecating auto-generated address
+ // after the preferred lifetime.
+
+ // As per RFC 4862 section 5.5.3.e, the valid lifetime of the
+ // address generated by SLAAC is as follows:
+ //
+ // 1) If the received Valid Lifetime is greater than 2 hours or
+ // greater than RemainingLifetime, set the valid lifetime of
+ // the address to the advertised Valid Lifetime.
+ //
+ // 2) If RemainingLifetime is less than or equal to 2 hours,
+ // ignore the advertised Valid Lifetime.
+ //
+ // 3) Otherwise, reset the valid lifetime of the address to 2
+ // hours.
+
+ // Handle the infinite valid lifetime separately as we do not
+ // keep a timer in this case.
+ if vl >= header.NDPInfiniteLifetime {
+ if addrState.invalidationTimer != nil {
+ // Valid lifetime was finite before, but now it
+ // is valid forever.
+ if !addrState.invalidationTimer.Stop() {
+ *addrState.doNotInvalidate = true
+ }
+ addrState.invalidationTimer = nil
+ addrState.validUntil = time.Time{}
+ ndp.autoGenAddresses[addr] = addrState
+ }
+
+ return
+ }
+
+ var effectiveVl time.Duration
+ var rl time.Duration
+
+ // If the address was originally set to be valid forever,
+ // assume the remaining time to be the maximum possible value.
+ if addrState.invalidationTimer == nil {
+ rl = header.NDPInfiniteLifetime
+ } else {
+ rl = time.Until(addrState.validUntil)
+ }
+
+ if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
+ effectiveVl = vl
+ } else if rl <= MinPrefixInformationValidLifetimeForUpdate {
+ ndp.autoGenAddresses[addr] = addrState
+ return
+ } else {
+ effectiveVl = MinPrefixInformationValidLifetimeForUpdate
+ }
+
+ if addrState.invalidationTimer == nil {
+ addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate)
+ } else {
+ if !addrState.invalidationTimer.Stop() {
+ *addrState.doNotInvalidate = true
+ }
+ addrState.invalidationTimer.Reset(effectiveVl)
+ }
+
+ addrState.validUntil = time.Now().Add(effectiveVl)
+ ndp.autoGenAddresses[addr] = addrState
+ return
+ }
+
+ // We do not already have an address within the prefix, prefix. Do the
+ // work as outlined by RFC 4862 section 5.5.3.d if the NIC is
+ // configured to auto-generate global addresses via SLAAC.
+
+ // Are we configured to auto-generate new global addresses?
+ if !ndp.configs.AutoGenGlobalAddresses {
+ return
+ }
+
+ // If we do not already have an address for this prefix and the valid
+ // lifetime is 0, no need to do anything further, as per RFC 4862
+ // section 5.5.3.d.
+ if vl == 0 {
+ return
+ }
+
+ // Make sure the prefix is valid (as far as its length is concerned) to
+ // generate a valid IPv6 address from an interface identifier (IID), as
+ // per RFC 4862 section 5.5.3.d.
+ if prefix.Prefix() != validPrefixLenForAutoGen {
+ return
+ }
+
+ // Only attempt to generate an interface-specific IID if we have a valid
+ // link address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address
+ // (provided by LinkEndpoint.LinkAddress) before reaching this
+ // point.
+ linkAddr := ndp.nic.linkEP.LinkAddress()
+ if !header.IsValidUnicastEthernetAddress(linkAddr) {
+ return
+ }
+
+ // Generate an address within prefix from the modified EUI-64 of ndp's
+ // NIC's Ethernet MAC address.
+ addrBytes := make([]byte, header.IPv6AddressSize)
+ copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address])
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ addr := tcpip.Address(addrBytes)
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: validPrefixLenForAutoGen,
+ }
+
+ // If the nic already has this address, do nothing further.
+ if ndp.nic.hasPermanentAddrLocked(addr) {
+ return
+ }
+
+ // Inform the integrator that we have a new SLAAC address.
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
+ return
+ }
+ if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) {
+ // Informed by the integrator not to add the address.
+ return
+ }
+
+ if _, err := ndp.nic.addAddressLocked(tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }, FirstPrimaryEndpoint, permanent, slaac); err != nil {
+ panic(err)
+ }
+
+ // Setup the timers to deprecate and invalidate this newly generated
+ // address.
+
+ // TODO(b/143713887): Handle deprecating auto-generated addresses
+ // after the preferred lifetime.
+
+ var doNotInvalidate bool
+ var vTimer *time.Timer
+ if vl < header.NDPInfiniteLifetime {
+ vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate)
+ }
+
+ ndp.autoGenAddresses[addr] = autoGenAddressState{
+ invalidationTimer: vTimer,
+ doNotInvalidate: &doNotInvalidate,
+ validUntil: time.Now().Add(vl),
+ }
+}
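The address construction above copies the upper 64 bits from the prefix and fills the interface identifier via header.EthernetAdddressToModifiedEUI64IntoBuf. A standalone sketch of that modified EUI-64 transform (RFC 4291 appendix A), illustrative only: splice ff:fe into the middle of the 48-bit MAC and flip the universal/local bit. This is also why the test link addresses below now start with 0x02 rather than 0x01: an odd first octet marks a multicast MAC, which fails the unicast check.

func slaacAddr(prefix [8]byte, mac [6]byte) [16]byte {
	var addr [16]byte
	copy(addr[:8], prefix[:])       // Upper 64 bits: the on-link prefix.
	copy(addr[8:11], mac[:3])       // OUI half of the MAC.
	addr[11], addr[12] = 0xff, 0xfe // EUI-64 filler bytes.
	copy(addr[13:], mac[3:])        // Device half of the MAC.
	addr[8] ^= 0x02                 // Flip the universal/local bit.
	return addr
}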
+
+// invalidateAutoGenAddress invalidates an auto-generated address.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) {
+ if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) {
+ return
+ }
+
+ ndp.nic.removePermanentAddressLocked(addr)
+}
+
+// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated
+// address's resources from ndp. If the stack has an NDP dispatcher, it will
+// be notified that addr has been invalidated.
+//
+// Returns true if ndp had resources for addr to cleanup.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool {
+ state, ok := ndp.autoGenAddresses[addr]
+
+ if !ok {
+ return false
+ }
+
+ if state.invalidationTimer != nil {
+ state.invalidationTimer.Stop()
+ state.invalidationTimer = nil
+ *state.doNotInvalidate = true
+ }
+
+ state.doNotInvalidate = nil
+
+ delete(ndp.autoGenAddresses, addr)
+
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: validPrefixLenForAutoGen,
+ })
+ }
+
+ return true
+}
+
+// autoGenAddrInvalidationTimer returns a new invalidation timer for an
+// auto-generated address that fires after vl.
+//
+// doNotInvalidate is used to inform the timer when it fires at the same time
+// that an auto-generated address's valid lifetime gets refreshed. See
+// autoGenAddrState.doNotInvalidate for more details.
+func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer {
+ return time.AfterFunc(vl, func() {
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
+ if *doNotInvalidate {
+ *doNotInvalidate = false
+ return
+ }
+
+ ndp.invalidateAutoGenAddress(addr)
+ })
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index cc789d70f..666f86c33 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -15,9 +15,12 @@
package stack_test
import (
+ "encoding/binary"
+ "fmt"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -29,12 +32,46 @@ import (
)
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- linkAddr1 = "\x01\x02\x03\x04\x05\x06"
- linkAddr2 = "\x01\x02\x03\x04\x05\x07"
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ linkAddr1 = "\x02\x02\x03\x04\x05\x06"
+ linkAddr2 = "\x02\x02\x03\x04\x05\x07"
+ linkAddr3 = "\x02\x02\x03\x04\x05\x08"
+ defaultTimeout = 100 * time.Millisecond
)
+var (
+ llAddr1 = header.LinkLocalAddr(linkAddr1)
+ llAddr2 = header.LinkLocalAddr(linkAddr2)
+ llAddr3 = header.LinkLocalAddr(linkAddr3)
+)
+
+// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent
+// tcpip.Subnet, and an address where the lower half of the address is composed
+// of the EUI-64 of linkAddr if it is a valid unicast Ethernet address.
+func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) {
+ prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0}
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(prefixBytes),
+ PrefixLen: 64,
+ }
+
+ subnet := prefix.Subnet()
+
+ var addr tcpip.AddressWithPrefix
+ if header.IsValidUnicastEthernetAddress(linkAddr) {
+ addrBytes := []byte(subnet.ID())
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: 64,
+ }
+ }
+
+ return prefix, subnet, addr
+}
+
// TestDADDisabled tests that an address successfully resolves immediately
// when DAD is not enabled (the default for an empty stack.Options).
func TestDADDisabled(t *testing.T) {
@@ -42,7 +79,7 @@ func TestDADDisabled(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(opts)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -68,6 +105,160 @@ func TestDADDisabled(t *testing.T) {
}
}
+// ndpDADEvent is a set of parameters that was passed to
+// ndpDispatcher.OnDuplicateAddressDetectionStatus.
+type ndpDADEvent struct {
+ nicID tcpip.NICID
+ addr tcpip.Address
+ resolved bool
+ err *tcpip.Error
+}
+
+type ndpRouterEvent struct {
+ nicID tcpip.NICID
+ addr tcpip.Address
+ // true if router was discovered, false if invalidated.
+ discovered bool
+}
+
+type ndpPrefixEvent struct {
+ nicID tcpip.NICID
+ prefix tcpip.Subnet
+ // true if prefix was discovered, false if invalidated.
+ discovered bool
+}
+
+type ndpAutoGenAddrEventType int
+
+const (
+ newAddr ndpAutoGenAddrEventType = iota
+ invalidatedAddr
+)
+
+type ndpAutoGenAddrEvent struct {
+ nicID tcpip.NICID
+ addr tcpip.AddressWithPrefix
+ eventType ndpAutoGenAddrEventType
+}
+
+type ndpRDNSS struct {
+ addrs []tcpip.Address
+ lifetime time.Duration
+}
+
+type ndpRDNSSEvent struct {
+ nicID tcpip.NICID
+ rdnss ndpRDNSS
+}
+
+var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
+
+// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
+// related events happen for test purposes.
+type ndpDispatcher struct {
+ dadC chan ndpDADEvent
+ routerC chan ndpRouterEvent
+ rememberRouter bool
+ prefixC chan ndpPrefixEvent
+ rememberPrefix bool
+ autoGenAddrC chan ndpAutoGenAddrEvent
+ rdnssC chan ndpRDNSSEvent
+}
+
+// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
+func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
+ if n.dadC != nil {
+ n.dadC <- ndpDADEvent{
+ nicID,
+ addr,
+ resolved,
+ err,
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered.
+func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool {
+ if c := n.routerC; c != nil {
+ c <- ndpRouterEvent{
+ nicID,
+ addr,
+ true,
+ }
+ }
+
+ return n.rememberRouter
+}
+
+// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated.
+func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) {
+ if c := n.routerC; c != nil {
+ c <- ndpRouterEvent{
+ nicID,
+ addr,
+ false,
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered.
+func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool {
+ if c := n.prefixC; c != nil {
+ c <- ndpPrefixEvent{
+ nicID,
+ prefix,
+ true,
+ }
+ }
+
+ return n.rememberPrefix
+}
+
+// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated.
+func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) {
+ if c := n.prefixC; c != nil {
+ c <- ndpPrefixEvent{
+ nicID,
+ prefix,
+ false,
+ }
+ }
+}
+
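+// Implements stack.NDPDispatcher.OnAutoGenAddress.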
+func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool {
+ if c := n.autoGenAddrC; c != nil {
+ c <- ndpAutoGenAddrEvent{
+ nicID,
+ addr,
+ newAddr,
+ }
+ }
+ return true
+}
+
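+// Implements stack.NDPDispatcher.OnAutoGenAddressInvalidated.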
+func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
+ if c := n.autoGenAddrC; c != nil {
+ c <- ndpAutoGenAddrEvent{
+ nicID,
+ addr,
+ invalidatedAddr,
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption.
+func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) {
+ if c := n.rdnssC; c != nil {
+ c <- ndpRDNSSEvent{
+ nicID,
+ ndpRDNSS{
+ addrs,
+ lifetime,
+ },
+ }
+ }
+}
+
// TestDADResolve tests that an address successfully resolves after performing
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
@@ -88,9 +279,17 @@ func TestDADResolve(t *testing.T) {
}
for _, test := range tests {
+ test := test
+
t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
}
opts.NDPConfigs.RetransmitTimer = test.retransTimer
opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
@@ -107,8 +306,7 @@ func TestDADResolve(t *testing.T) {
stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit
- // Should have sent an NDP NS almost immediately.
- time.Sleep(100 * time.Millisecond)
+ // Should have sent an NDP NS immediately.
if got := stat.Value(); got != 1 {
t.Fatalf("got NeighborSolicit = %d, want = 1", got)
@@ -124,16 +322,10 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
}
- // Wait for the remaining time - 500ms, to make sure
- // the address is still not resolved. Note, we subtract
- // 600ms because we already waited for 100ms earlier,
- // so our remaining time is 100ms less than the expected
- // time.
- // (X - 100ms) - 500ms = X - 600ms
- //
- // TODO(b/140896005): Use events from the netstack to
- // be signalled before DAD resolves.
- time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - 600*time.Millisecond)
+ // Wait for the remaining time - some delta (500ms), to
+ // make sure the address is still not resolved.
+ const delta = 500 * time.Millisecond
+ time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
@@ -142,13 +334,30 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
}
- // Wait for the remaining time + 250ms, at which point
- // the address should be resolved. Note, the remaining
- // time is 500ms. See above comments.
- //
- // TODO(b/140896005): Use events from the netstack to
- // know immediately when DAD completes.
- time.Sleep(750 * time.Millisecond)
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(2 * delta):
+ // We should get a resolution event after 500ms
+ // (delta) since we wait for 500ms less than the
+ // expected resolution time above to make sure
+ // that the address did not yet resolve. Waiting
+ // for 1s (2x delta) without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
@@ -172,7 +381,8 @@ func TestDADResolve(t *testing.T) {
}
// Check NDP packet.
- checker.IPv6(t, p.Header.ToVectorisedView().First(),
+ checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(),
+ checker.TTL(header.NDPHopLimit),
checker.NDPNS(
checker.NDPNSTargetAddress(addr1)))
}
@@ -250,13 +460,18 @@ func TestDADFail(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.DefaultNDPConfigurations(),
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
}
opts.NDPConfigs.RetransmitTimer = time.Second * 2
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(opts)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -279,15 +494,37 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate multiple nodes owning or
// attempting to own the same address.
hdr := test.makeBuf(addr1)
- e.Inject(header.IPv6ProtocolNumber, hdr.View().ToVectorisedView())
+ e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
if got := stat.Value(); got != 1 {
t.Fatalf("got stat = %d, want = 1", got)
}
- // Wait 3 seconds to make sure that DAD did not resolve
- time.Sleep(3 * time.Second)
+ // Wait for DAD to fail and make sure the address did
+ // not get resolved.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the
+ // expected resolution time + extra 1s buffer,
+ // something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if e.resolved {
+ t.Fatal("got DAD event w/ resolved = true, want = false")
+ }
+ }
addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
@@ -302,13 +539,20 @@ func TestDADFail(t *testing.T) {
// TestDADStop tests to make sure that the DAD process stops when an address is
// removed.
func TestDADStop(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.NDPConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 2,
+ }
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
}
- opts.NDPConfigs.RetransmitTimer = time.Second
- opts.NDPConfigs.DupAddrDetectTransmits = 2
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(opts)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -332,11 +576,27 @@ func TestDADStop(t *testing.T) {
t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err)
}
- // Wait for the time to normally resolve
- // DupAddrDetectTransmits(2) * RetransmitTimer(1s) = 2s.
- // An extra 250ms is added to make sure that if DAD was still running
- // it resolves and the check below fails.
- time.Sleep(2*time.Second + 250*time.Millisecond)
+ // Wait for DAD to fail (since the address was removed during DAD).
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if e.resolved {
+ t.Fatal("got DAD event w/ resolved = true, want = false")
+ }
+ }
addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
@@ -350,3 +610,1448 @@ func TestDADStop(t *testing.T) {
t.Fatalf("got NeighborSolicit = %d, want <= 1", got)
}
}
+
+// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if
+// we attempt to update NDP configurations using an invalid NICID.
+func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+
+ // No NIC with ID 1 yet.
+ if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID)
+ }
+}
+
+// TestSetNDPConfigurations tests that we can update and use per-interface NDP
+// configurations without affecting the default NDP configurations or other
+// interfaces' configurations.
+func TestSetNDPConfigurations(t *testing.T) {
+ tests := []struct {
+ name string
+ dupAddrDetectTransmits uint8
+ retransmitTimer time.Duration
+ expectedRetransmitTimer time.Duration
+ }{
+ {
+ "OK",
+ 1,
+ time.Second,
+ time.Second,
+ },
+ {
+ "Invalid Retransmit Timer",
+ 1,
+ 0,
+ time.Second,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ })
+
+ // NIC(1)'s NDP configurations will be updated to
+ // differ from the default.
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // NIC(2) is created before updating NIC(1)'s NDP
+ // configurations; updating NIC(1)'s configurations
+ // should not affect other existing NICs.
+ if err := s.CreateNIC(2, e); err != nil {
+ t.Fatalf("CreateNIC(2) = %s", err)
+ }
+
+ // Update the NDP configurations on NIC(1) to use DAD.
+ configs := stack.NDPConfigurations{
+ DupAddrDetectTransmits: test.dupAddrDetectTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ }
+ if err := s.SetNDPConfigurations(1, configs); err != nil {
+ t.Fatalf("got SetNDPConfigurations(1, _) = %s", err)
+ }
+
+ // NIC(3) is created after updating NIC(1)'s NDP
+ // configurations; it should use the stack's default NDP
+ // configurations, which should not have been modified.
+ if err := s.CreateNIC(3, e); err != nil {
+ t.Fatalf("CreateNIC(3) = %s", err)
+ }
+
+ // Add addresses for each NIC.
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+ if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err)
+ }
+ if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err)
+ }
+
+ // Address should not be considered bound to NIC(1) yet
+ // (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // The addresses on NIC(2) and NIC(3) should be
+ // available immediately since DAD should not have been
+ // performed on them: the stack was configured to not do
+ // DAD by default and only NIC(1)'s NDP configurations
+ // were updated.
+ addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err)
+ }
+ if addr.Address != addr2 {
+ t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2)
+ }
+ addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err)
+ }
+ if addr.Address != addr3 {
+ t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3)
+ }
+
+ // Sleep until 500ms (delta) before the expected
+ // resolution time to make sure the address has not
+ // resolved on NIC(1) yet.
+ const delta = 500 * time.Millisecond
+ time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(2 * delta):
+ // We should get a resolution event after 500ms
+ // (delta) since we wait for 500ms less than the
+ // expected resolution time above to make sure
+ // that the address did not yet resolve. Waiting
+ // for 1s (2x delta) without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1)
+ }
+ })
+ }
+}
+
+// raBufWithOpts returns a valid NDP Router Advertisement with options.
+//
+// Note, raBufWithOpts does not populate any of the RA fields other than the
+// Router Lifetime.
+func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
+ icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(0)
+ ra := header.NDPRouterAdvert(pkt.NDPPayload())
+ opts := ra.Options()
+ opts.Serialize(optSer)
+ // Populate the Router Lifetime.
+ binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ iph.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: ip,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()}
+}
+
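+// For reference, the Router Lifetime written by raBufWithOpts lives at bytes
+// 2-3 of the RA body (the NDP payload), so a receiver could recover it with
+// something like the following sketch (illustrative only, assuming ra holds
+// the NDP payload of a Router Advertisement):
+//
+//	lifetime := time.Duration(binary.BigEndian.Uint16(ra[2:4])) * time.Second
+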
+// raBuf returns a valid NDP Router Advertisement.
+//
+// Note, raBuf does not populate any of the RA fields other than the
+// Router Lifetime.
+func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer {
+ return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
+}
+
+// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix
+// Information option.
+//
+// Note, raBufWithPI does not populate any of the RA fields other than the
+// Router Lifetime.
+func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer {
+ flags := uint8(0)
+ if onLink {
+ // The OnLink flag is the 7th bit in the flags byte.
+ flags |= 1 << 7
+ }
+ if auto {
+ // The Address Auto-Configuration flag is the 6th bit in the
+ // flags byte.
+ flags |= 1 << 6
+ }
+
+ // A valid header.NDPPrefixInformation must be 30 bytes.
+ buf := [30]byte{}
+ // The first byte in a header.NDPPrefixInformation is the Prefix Length
+ // field.
+ buf[0] = uint8(prefix.PrefixLen)
+ // The 2nd byte within a header.NDPPrefixInformation is the Flags field.
+ buf[1] = flags
+ // The Valid Lifetime field starts after the 2nd byte within a
+ // header.NDPPrefixInformation.
+ binary.BigEndian.PutUint32(buf[2:], vl)
+ // The Preferred Lifetime field starts after the 6th byte within a
+ // header.NDPPrefixInformation.
+ binary.BigEndian.PutUint32(buf[6:], pl)
+ // The Prefix Address field starts after the 14th byte within a
+ // header.NDPPrefixInformation.
+ copy(buf[14:], prefix.Address)
+ return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{
+ header.NDPPrefixInformation(buf[:]),
+ })
+}
+
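+// For reference, the Prefix Information option body built by raBufWithPI can
+// be read back with a sketch like the following (illustrative only; the
+// offsets mirror the comments in raBufWithPI):
+//
+//	func parsePI(buf []byte) (prefixLen uint8, onLink, auto bool, vl, pl uint32, prefix tcpip.Address) {
+//		prefixLen = buf[0]
+//		onLink = buf[1]&(1<<7) != 0
+//		auto = buf[1]&(1<<6) != 0
+//		vl = binary.BigEndian.Uint32(buf[2:])
+//		pl = binary.BigEndian.Uint32(buf[6:])
+//		prefix = tcpip.Address(buf[14:30])
+//		return
+//	}
+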
+// TestNoRouterDiscovery tests that router discovery will not be performed if
+// configured not to.
+func TestNoRouterDiscovery(t *testing.T) {
+ // Being configured to discover routers means handle and
+ // discover are set to true and forwarding is set to false.
+ // This tests all possible combinations of the configurations,
+ // except for the configuration where handle = true, discover =
+ // true and forwarding = false (the required configuration to do
+ // router discovery) - that is exercised in other tests.
+ for i := 0; i < 7; i++ {
+ handle := i&1 != 0
+ discover := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverDefaultRouters: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ s.SetForwarding(forwarding)
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("unexpectedly discovered a router when configured not to")
+ default:
+ }
+ })
+ }
+}
+
+// checkRouterEvent checks that e is a router event for addr on the NIC with
+// ID 1 with the discovered flag set to discovered, returning a non-empty
+// diff on mismatch.
+func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string {
+ return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e))
+}
+
+// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
+// remember a discovered router when the dispatcher asks it not to.
+func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA for a router we should not remember.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds))
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, llAddr2, true); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
+ }
+
+ // Wait for the invalidation time plus some buffer to make sure we do
+ // not actually receive any invalidation events as we should not have
+ // remembered the router in the first place.
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("should not have received any router events")
+ case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ }
+}
+
+func TestRouterDiscovery(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ expectRouterEvent := func(addr tcpip.Address, discovered bool) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, discovered); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
+ }
+ }
+
+ expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, false); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for router discovery event")
+ }
+ }
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA from lladdr2 with zero lifetime. It should not be
+ // remembered.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("unexpectedly discovered a router with 0 lifetime")
+ default:
+ }
+
+ // Rx an RA from lladdr2 with a huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ expectRouterEvent(llAddr2, true)
+
+ // Rx an RA from another router (lladdr3) with non-zero lifetime.
+ l3Lifetime := time.Duration(6)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime)))
+ expectRouterEvent(llAddr3, true)
+
+ // Rx an RA from lladdr2 with lesser lifetime.
+ l2Lifetime := time.Duration(2)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime)))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("Should not receive a router event when updating lifetimes for known routers")
+ default:
+ }
+
+ // Wait for lladdr2's router invalidation timer to fire. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncRouterInvalidationEvent(llAddr2, l2Lifetime*time.Second+defaultTimeout)
+
+ // Rx an RA from lladdr2 with huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ expectRouterEvent(llAddr2, true)
+
+ // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ expectRouterEvent(llAddr2, false)
+
+ // Wait for lladdr3's router invalidation timer to fire. The router
+ // should be invalidated after the lifetime received in the only RA
+ // from lladdr3 (l3Lifetime) expires.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncRouterInvalidationEvent(llAddr3, l3Lifetime*time.Second+defaultTimeout)
+}
+
+// TestRouterDiscoveryMaxRouters tests that only
+// stack.MaxDiscoveredDefaultRouters discovered routers are remembered.
+func TestRouterDiscoveryMaxRouters(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive RAs from 2 more routers than the max number of discovered routers.
+ for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ {
+ linkAddr := []byte{2, 2, 3, 4, 5, 0}
+ linkAddr[5] = byte(i)
+ llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5))
+
+ if i <= stack.MaxDiscoveredDefaultRouters {
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, llAddr, true); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
+ }
+ } else {
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("should not have discovered a new router after we already discovered the max number of routers")
+ default:
+ }
+ }
+ }
+}
+
+// TestNoPrefixDiscovery tests that prefix discovery will not be performed if
+// configured not to.
+func TestNoPrefixDiscovery(t *testing.T) {
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
+ PrefixLen: 64,
+ }
+
+ // Being configured to discover prefixes means handle and
+ // discover are set to true and forwarding is set to false.
+ // This tests all possible combinations of the configurations,
+ // except for the configuration where handle = true, discover =
+ // true and forwarding = false (the required configuration to do
+ // prefix discovery) - that is exercised in other tests.
+ for i := 0; i < 7; i++ {
+ handle := i&1 != 0
+ discover := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverOnLinkPrefixes: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ s.SetForwarding(forwarding)
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA with prefix with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0))
+
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly discovered a prefix when configured not to")
+ default:
+ }
+ })
+ }
+}
+
+// checkPrefixEvent checks that e is a prefix event for prefix on the NIC
+// with ID 1 with the discovered flag set to discovered, returning a
+// non-empty diff on mismatch.
+func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string {
+ return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e))
+}
+
+// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not
+// remember a discovered on-link prefix when the dispatcher asks it not to.
+func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
+ t.Parallel()
+
+ prefix, subnet, _ := prefixSubnetAddr(0, "")
+
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: false,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA with prefix that we should not remember.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0))
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet, true); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
+ }
+
+ // Wait for the invalidation time plus some buffer to make sure we do
+ // not actually receive any invalidation events as we should not have
+ // remembered the prefix in the first place.
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("should not have received any prefix events")
+ case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ }
+}
+
+func TestPrefixDiscovery(t *testing.T) {
+ t.Parallel()
+
+ prefix1, subnet1, _ := prefixSubnetAddr(0, "")
+ prefix2, subnet2, _ := prefixSubnetAddr(1, "")
+ prefix3, subnet3, _ := prefixSubnetAddr(2, "")
+
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ rememberPrefix: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly discovered a prefix with 0 lifetime")
+ default:
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0))
+ expectPrefixEvent(subnet1, true)
+
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0))
+ expectPrefixEvent(subnet2, true)
+
+ // Receive an RA with prefix3 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0))
+ expectPrefixEvent(subnet3, true)
+
+ // Receive an RA with prefix1 in a PI with lifetime = 0.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
+ expectPrefixEvent(subnet1, false)
+
+ // Receive an RA with prefix2 in a PI with lesser lifetime.
+ lifetime := uint32(2)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly received prefix event when updating lifetime")
+ default:
+ }
+
+ // Wait for prefix2's most recent invalidation timer plus some buffer to
+ // expire.
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultTimeout):
+ t.Fatal("timed out waiting for prefix discovery event")
+ }
+
+ // Receive RA to invalidate prefix3.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0))
+ expectPrefixEvent(subnet3, false)
+}
+
+func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
+ // Update the infinite lifetime value to a smaller value so we can test
+ // that when we receive a PI with such a lifetime value, we do not
+ // invalidate the prefix.
+ const testInfiniteLifetimeSeconds = 2
+ const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second
+ saved := header.NDPInfiniteLifetime
+ header.NDPInfiniteLifetime = testInfiniteLifetime
+ defer func() {
+ header.NDPInfiniteLifetime = saved
+ }()
+
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
+ PrefixLen: 64,
+ }
+ subnet := prefix.Subnet()
+
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ rememberPrefix: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
+ }
+ }
+
+ // Receive an RA with prefix in an NDP Prefix Information option (PI)
+ // with infinite valid lifetime which should not get invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
+ expectPrefixEvent(subnet, true)
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
+ case <-time.After(testInfiniteLifetime + defaultTimeout):
+ }
+
+ // Receive an RA with finite lifetime.
+ // The prefix should get invalidated after 1s.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet, false); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(testInfiniteLifetime):
+ t.Fatal("timed out waiting for prefix discovery event")
+ }
+
+ // Receive an RA with finite lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
+ expectPrefixEvent(subnet, true)
+
+ // Receive an RA with prefix with an infinite lifetime.
+ // The prefix should not be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
+ case <-time.After(testInfiniteLifetime + defaultTimeout):
+ }
+
+ // Receive an RA with a prefix with a lifetime value greater than the
+ // set infinite lifetime value.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
+ case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout):
+ }
+
+ // Receive an RA with 0 lifetime.
+ // The prefix should get invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0))
+ expectPrefixEvent(subnet, false)
+}
+
+// TestPrefixDiscoveryMaxOnLinkPrefixes tests that only
+// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
+func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3),
+ rememberPrefix: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: false,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2)
+ prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{}
+
+ // Receive an RA with 2 more prefixes than the max number of
+ // discovered on-link prefixes.
+ for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
+ prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0}
+ prefixAddr[7] = byte(i)
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(prefixAddr[:]),
+ PrefixLen: 64,
+ }
+ prefixes[i] = prefix.Subnet()
+ buf := [30]byte{}
+ buf[0] = uint8(prefix.PrefixLen)
+ buf[1] = 128
+ binary.BigEndian.PutUint32(buf[2:], 10)
+ copy(buf[14:], prefix.Address)
+
+ optSer[i] = header.NDPPrefixInformation(buf[:])
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
+ for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
+ if i < stack.MaxDiscoveredOnLinkPrefixes {
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
+ }
+ } else {
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes")
+ default:
+ }
+ }
+ }
+}
+
+// contains reports whether list contains the IPv6 address item.
+func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: item,
+ }
+
+ for _, i := range list {
+ if i == protocolAddress {
+ return true
+ }
+ }
+
+ return false
+}
+
+// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to.
+func TestNoAutoGenAddr(t *testing.T) {
+ prefix, _, _ := prefixSubnetAddr(0, "")
+
+ // Being configured to auto-generate addresses means handle and
+ // autogen are set to true and forwarding is set to false.
+ // This tests all possible combinations of the configurations,
+ // except for the configuration where handle = true, autogen =
+ // true and forwarding = false (the required configuration to do
+ // SLAAC) - that is exercised in other tests.
+ for i := 0; i < 7; i++ {
+ handle := i&1 != 0
+ autogen := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: handle,
+ AutoGenGlobalAddresses: autogen,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ s.SetForwarding(forwarding)
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA with prefix with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0))
+
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address when configured not to")
+ default:
+ }
+ })
+ }
+}
+
+// checkAutoGenAddrEvent checks that e is an auto-generated address event for
+// addr on the NIC with ID 1 with the event type set to eventType, returning
+// a non-empty diff on mismatch.
+func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string {
+ return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e))
+}
+
+// TestAutoGenAddr tests that an address is properly generated and invalidated
+// when configured to do so.
+func TestAutoGenAddr(t *testing.T) {
+ const newMinVL = 2
+ newMinVLDuration := newMinVL * time.Second
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
+ default:
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
+ default:
+ }
+
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
+
+ // Refresh valid lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
+ default:
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(newMinVLDuration + defaultTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should not have %s in the list of addresses", addr1)
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
+}
+
+// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an
+// auto-generated address only gets updated when required to, as specified in
+// RFC 4862 section 5.5.3.e.
+func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
+ const infiniteVL = 4294967295
+ const newMinVL = 5
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ tests := []struct {
+ name string
+ ovl uint32
+ nvl uint32
+ evl uint32
+ }{
+ // Should update the VL to the minimum VL for updating if the
+ // new VL is less than newMinVL but was originally greater than
+ // it.
+ {
+ "LargeVLToVLLessThanMinVLForUpdate",
+ 9999,
+ 1,
+ newMinVL,
+ },
+ {
+ "LargeVLTo0",
+ 9999,
+ 0,
+ newMinVL,
+ },
+ {
+ "InfiniteVLToVLLessThanMinVLForUpdate",
+ infiniteVL,
+ 1,
+ newMinVL,
+ },
+ {
+ "InfiniteVLTo0",
+ infiniteVL,
+ 0,
+ newMinVL,
+ },
+
+ // Should not update VL if original VL was less than newMinVL
+ // and the new VL is also less than newMinVL.
+ {
+ "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate",
+ newMinVL - 1,
+ newMinVL - 3,
+ newMinVL - 1,
+ },
+
+ // Should take the new VL if the new VL is greater than the
+ // remaining time or is greater than newMinVL.
+ {
+ "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate",
+ newMinVL + 5,
+ newMinVL + 3,
+ newMinVL + 3,
+ },
+ {
+ "SmallVLToGreaterVLButStillLessThanMinVLForUpdate",
+ newMinVL - 3,
+ newMinVL - 1,
+ newMinVL - 1,
+ },
+ {
+ "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate",
+ newMinVL - 3,
+ newMinVL + 1,
+ newMinVL + 1,
+ },
+ }
+
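+ // The cases above encode the update rule of RFC 4862 section 5.5.3.e
+ // with the 2 hour constant replaced by newMinVL. A sketch of that
+ // rule, assuming remaining is the valid lifetime left on the existing
+ // address and received is the newly advertised valid lifetime
+ // (illustrative only):
+ //
+ //	func updatedVL(remaining, received, minVL time.Duration) time.Duration {
+ //		switch {
+ //		case received > minVL || received > remaining:
+ //			return received
+ //		case remaining <= minVL:
+ //			return remaining // Ignore the advertised lifetime.
+ //		default:
+ //			return minVL
+ //		}
+ //	}
+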
+ const delta = 500 * time.Millisecond
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we have teardown work that must run only
+ // after the parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA with prefix with initial VL,
+ // test.ovl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+
+ // Receive a new RA with prefix with the new VL,
+ // test.nvl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
+
+ //
+ // Validate that the VL for the address got set
+ // to test.evl.
+ //
+
+ // Make sure we do not get any invalidation
+ // events until at least 500ms (delta) before
+ // test.evl.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly received an auto gen addr event")
+ case <-time.After(time.Duration(test.evl)*time.Second - delta):
+ }
+
+ // Wait for another second (2x delta), but now
+ // we expect the invalidation event.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ case <-time.After(2 * delta):
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
+ })
+}
+
+// TestAutoGenAddrRemoval tests that when an auto-generated address is removed
+// by the user, its resources are cleaned up and an invalidation event is sent
+// to the integrator.
+func TestAutoGenAddrRemoval(t *testing.T) {
+ t.Parallel()
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive a PI to auto-generate an address.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Removing the address should result in an invalidation event
+ // immediately.
+ if err := s.RemoveAddress(1, addr.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+
+ // Wait for the original valid lifetime to make sure the original timer
+ // got stopped/cleaned up.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly received an auto gen addr event")
+ case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ }
+}
+
+// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
+// is already assigned to the NIC, the static address remains.
+func TestAutoGenAddrStaticConflict(t *testing.T) {
+ t.Parallel()
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Add the address as a static address before SLAAC tries to add it.
+ if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+
+ // Receive a PI where the generated address will be the same as the one
+ // that we already have assigned statically.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically")
+ default:
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+
+ // Should not get an invalidation event after the PI's invalidation
+ // time.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event")
+ case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ }
+ if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+}
+
+// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event
+// to the integrator when an RA is received with the NDP Recursive DNS Server
+// option with at least one valid address.
+func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ opt header.NDPRecursiveDNSServer
+ expected *ndpRDNSS
+ }{
+ {
+ "Unspecified",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }),
+ nil,
+ },
+ {
+ "Multicast",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ }),
+ nil,
+ },
+ {
+ "OptionTooSmall",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 1, 2, 3, 4, 5, 6, 7, 8,
+ }),
+ nil,
+ },
+ {
+ "0Addresses",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ }),
+ nil,
+ },
+ {
+ "Valid1Address",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
+ }),
+ &ndpRDNSS{
+ []tcpip.Address{
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
+ },
+ 2 * time.Second,
+ },
+ },
+ {
+ "Valid2Addresses",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2,
+ }),
+ &ndpRDNSS{
+ []tcpip.Address{
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02",
+ },
+ time.Second,
+ },
+ },
+ {
+ "Valid3Addresses",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3,
+ }),
+ &ndpRDNSS{
+ []tcpip.Address{
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02",
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03",
+ },
+ 0,
+ },
+ },
+ }
+
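+ // For reference, the raw option bodies above follow RFC 8106 section
+ // 5.1: 2 reserved bytes, a 4-byte lifetime in seconds, then zero or
+ // more 16-byte IPv6 addresses (the serializer prepends the type and
+ // length bytes). A sketch of reading one back (illustrative only,
+ // assuming body is a well-formed option body):
+ //
+ //	lifetime := time.Duration(binary.BigEndian.Uint32(body[2:6])) * time.Second
+ //	firstAddr := tcpip.Address(body[6:22])
+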
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ // We do not expect more than a single RDNSS
+ // event at any time for this test.
+ rdnssC: make(chan ndpRDNSSEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt}))
+
+ if test.expected != nil {
+ select {
+ case e := <-ndpDisp.rdnssC:
+ if e.nicID != 1 {
+ t.Errorf("got rdnss nicID = %d, want = 1", e.nicID)
+ }
+ if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" {
+ t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff)
+ }
+ if e.rdnss.lifetime != test.expected.lifetime {
+ t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime)
+ }
+ default:
+ t.Fatal("expected an RDNSS option event")
+ }
+ }
+
+ // Should have no more RDNSS options.
+ select {
+ case e := <-ndpDisp.rdnssC:
+ t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e)
+ default:
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 43f4ad91e..e8401c673 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -19,7 +19,6 @@ import (
"sync"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/ilist"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -37,13 +36,20 @@ type NIC struct {
mu sync.RWMutex
spoofing bool
promiscuous bool
- primary map[tcpip.NetworkProtocolNumber]*ilist.List
+ primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
addressRanges []tcpip.Subnet
mcastJoins map[NetworkEndpointID]int32
+ // packetEPs is protected by mu, but the contained PacketEndpoint
+ // values are not.
+ packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
stats NICStats
+ // ndp is the NDP-related state for the NIC.
+ //
+ // Note, read and write operations on ndp require that the NIC is
+ // appropriately locked.
ndp ndpState
}
@@ -78,16 +84,26 @@ const (
NeverPrimaryEndpoint
)
+// newNIC returns a new NIC using the default NDP configurations from stack.
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
- return &NIC{
+ // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
+ // example, make sure that the link address it provides is a valid
+ // unicast ethernet address.
+
+ // TODO(b/143357959): RFC 8200 section 5 requires that IPv6 endpoints
+ // observe an MTU of at least 1280 bytes. Ensure that this requirement
+ // of IPv6 is supported on this endpoint's LinkEndpoint.
+
+ nic := &NIC{
stack: stack,
id: id,
name: name,
linkEP: ep,
loopback: loopback,
- primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
+ primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
+ packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint),
stats: NICStats{
Tx: DirectionStats{
Packets: &tcpip.StatCounter{},
@@ -99,9 +115,24 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
},
},
ndp: ndpState{
- dad: make(map[tcpip.Address]dadState),
+ configs: stack.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ autoGenAddresses: make(map[tcpip.Address]autoGenAddressState),
},
}
+ nic.ndp.nic = nic
+
+ // Register supported packet endpoint protocols.
+ for _, netProto := range header.Ethertypes {
+ nic.packetEPs[netProto] = []PacketEndpoint{}
+ }
+ for _, netProto := range stack.networkProtocols {
+ nic.packetEPs[netProto.Number()] = []PacketEndpoint{}
+ }
+
+ return nic
}
// enable enables the NIC. enable will attach the link to its LinkEndpoint and
@@ -126,11 +157,50 @@ func (n *NIC) enable() *tcpip.Error {
// when we perform Duplicate Address Detection, or Router Advertisement
// when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
// section 4.2 for more information.
- if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
- return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress)
+ //
+ // Also auto-generate an IPv6 link-local address based on the NIC's
+ // link address if it is configured to do so. Note, each interface is
+ // required to have an IPv6 link-local unicast address, as per RFC 4291
+ // section 2.1.
+ _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]
+ if !ok {
+ return nil
}
- return nil
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
+ if !n.stack.autoGenIPv6LinkLocal {
+ return nil
+ }
+
+ l2addr := n.linkEP.LinkAddress()
+
+ // Only attempt to generate the link-local address if we have a
+ // valid MAC address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address
+ // (provided by LinkEndpoint.LinkAddress) before reaching this
+ // point.
+ if !header.IsValidUnicastEthernetAddress(l2addr) {
+ return nil
+ }
+
+ addr := header.LinkLocalAddr(l2addr)
+
+ _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
+ },
+ }, CanBePrimaryEndpoint)
+
+ return err
}
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
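
For context, header.LinkLocalAddr (used in the enable hunk above) derives an fe80::/64 address from the NIC's MAC address using the modified EUI-64 scheme of RFC 4291 Appendix A. A hand-rolled sketch of that derivation, for illustration only; the real helper lives in pkg/tcpip/header:

// linkLocalFromMAC sketches modified EUI-64 link-local derivation
// (RFC 4291 Appendix A). netstack's real helper is header.LinkLocalAddr.
func linkLocalFromMAC(mac [6]byte) tcpip.Address {
	var addr [16]byte
	addr[0], addr[1] = 0xfe, 0x80 // fe80::/64 prefix
	// Interface identifier: the MAC split around ff:fe, with the
	// universal/local bit of the first octet flipped.
	addr[8] = mac[0] ^ 0x02
	addr[9], addr[10] = mac[1], mac[2]
	addr[11], addr[12] = 0xff, 0xfe
	addr[13], addr[14], addr[15] = mac[3], mac[4], mac[5]
	return tcpip.Address(addr[:])
}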
@@ -166,18 +236,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
n.mu.RLock()
defer n.mu.RUnlock()
- list := n.primary[protocol]
- if list == nil {
- return nil
- }
-
- for e := list.Front(); e != nil; e = e.Next() {
- r := e.(*referencedNetworkEndpoint)
- // TODO(crawshaw): allow broadcast address when SO_BROADCAST is set.
- switch r.ep.ID().LocalAddress {
- case header.IPv4Broadcast, header.IPv4Any:
- continue
- }
+ for _, r := range n.primary[protocol] {
if r.isValidForOutgoing() && r.tryIncRef() {
return r
}
@@ -186,6 +245,20 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
return nil
}
+// hasPermanentAddrLocked returns true if n has addr as a permanent (including
+// currently tentative) address.
+func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool {
+ ref, ok := n.endpoints[NetworkEndpointID{addr}]
+
+ if !ok {
+ return false
+ }
+
+ kind := ref.getKind()
+
+ return kind == permanent || kind == permanentTentative
+}
+
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous)
}
@@ -277,7 +350,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
Address: address,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, peb, temporary)
+ }, peb, temporary, static)
n.mu.Unlock()
return ref
@@ -291,9 +364,31 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
// The NIC already has a permanent endpoint with that address.
return nil, tcpip.ErrDuplicateAddress
case permanentExpired, temporary:
- // Promote the endpoint to become permanent.
+ // Promote the endpoint to become permanent and respect
+ // the new peb.
if ref.tryIncRef() {
ref.setKind(permanent)
+
+ refs := n.primary[ref.protocol]
+ for i, r := range refs {
+ if r == ref {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ return ref, nil
+ case FirstPrimaryEndpoint:
+ if i == 0 {
+ return ref, nil
+ }
+ n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ case NeverPrimaryEndpoint:
+ n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ return ref, nil
+ }
+ }
+ }
+
+ n.insertPrimaryEndpointLocked(ref, peb)
+
return ref, nil
}
// tryIncRef failing means the endpoint is scheduled to be removed once
@@ -304,10 +399,10 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
}
}
- return n.addAddressLocked(protocolAddress, peb, permanent)
+ return n.addAddressLocked(protocolAddress, peb, permanent, static)
}
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) {
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType) (*referencedNetworkEndpoint, *tcpip.Error) {
// TODO(b/141022673): Validate IP address before adding them.
// Sanity check.
@@ -337,11 +432,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
}
ref := &referencedNetworkEndpoint{
- refs: 1,
- ep: ep,
- nic: n,
- protocol: protocolAddress.Protocol,
- kind: kind,
+ refs: 1,
+ ep: ep,
+ nic: n,
+ protocol: protocolAddress.Protocol,
+ kind: kind,
+ configType: configType,
}
// Set up cache if link address resolution exists for this protocol.
@@ -362,22 +458,11 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
n.endpoints[id] = ref
- l, ok := n.primary[protocolAddress.Protocol]
- if !ok {
- l = &ilist.List{}
- n.primary[protocolAddress.Protocol] = l
- }
-
- switch peb {
- case CanBePrimaryEndpoint:
- l.PushBack(ref)
- case FirstPrimaryEndpoint:
- l.PushFront(ref)
- }
+ n.insertPrimaryEndpointLocked(ref, peb)
// If we are adding a tentative IPv6 address, start DAD.
if isIPv6Unicast && kind == permanentTentative {
- if err := n.ndp.startDuplicateAddressDetection(n, protocolAddress.AddressWithPrefix.Address, ref); err != nil {
+ if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
return nil, err
}
}
@@ -430,8 +515,7 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for proto, list := range n.primary {
- for e := list.Front(); e != nil; e = e.Next() {
- ref := e.(*referencedNetworkEndpoint)
+ for _, ref := range list {
// Don't include tentative, expired or temporary endpoints
// to avoid confusion and prevent the caller from using
// those.
@@ -496,6 +580,19 @@ func (n *NIC) AddressRanges() []tcpip.Subnet {
return append(sns, n.addressRanges...)
}
+// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required
+// by peb.
+//
+// n MUST be locked.
+func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ n.primary[r.protocol] = append(n.primary[r.protocol], r)
+ case FirstPrimaryEndpoint:
+ n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...)
+ }
+}
+
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
id := *r.ep.ID()
@@ -513,9 +610,12 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
}
delete(n.endpoints, id)
- wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front()
- if wasInList {
- n.primary[r.protocol].Remove(r)
+ refs := n.primary[r.protocol]
+ for i, ref := range refs {
+ if ref == r {
+ n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ break
+ }
}
r.ep.Close()
@@ -540,9 +640,18 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr)
- // If we are removing a tentative IPv6 unicast address, stop DAD.
- if isIPv6Unicast && kind == permanentTentative {
- n.ndp.stopDuplicateAddressDetection(addr)
+ if isIPv6Unicast {
+ // If we are removing a tentative IPv6 unicast address, stop
+ // DAD.
+ if kind == permanentTentative {
+ n.ndp.stopDuplicateAddressDetection(addr)
+ }
+
+ // If we are removing an address generated via SLAAC, cleanup
+ // its SLAAC resources and notify the integrator.
+ if r.configType == slaac {
+ n.ndp.cleanupAutoGenAddrResourcesAndNotify(addr)
+ }
}
r.setKind(permanentExpired)
@@ -585,6 +694,11 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
// exists yet. Otherwise it just increments its count. n MUST be locked before
// joinGroupLocked is called.
func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ // TODO(b/143102137): When implementing MLD, make sure MLD packets are
+ // not sent unless a valid link-local address is available for use on n
+ // as an MLD packet's source address must be a link-local address as
+ // outlined in RFC 3810 section 5.
+
id := NetworkEndpointID{addr}
joins := n.mcastJoins[id]
if joins == 0 {
@@ -635,10 +749,10 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
return nil
}
-func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) {
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
- ref.ep.HandlePacket(&r, vv)
+ ref.ep.HandlePacket(&r, pkt)
ref.decRef()
}
@@ -648,9 +762,9 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// Note that the ownership of the slice backing pkt.Data is retained by the
// caller. This rule applies only to the slice itself, not to the items of the
// slice; the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
n.stats.Rx.Packets.Increment()
- n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size()))
+ n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
@@ -658,19 +772,39 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
return
}
+ // If no local link layer address is provided, assume it was sent
+ // directly to this NIC.
+ if local == "" {
+ local = n.linkEP.LinkAddress()
+ }
+
+ // Are any packet sockets listening for this network protocol?
+ n.mu.RLock()
+ packetEPs := n.packetEPs[protocol]
+ // Also deliver to endpoints registered for all protocols
+ // (EthernetProtocolAll). If the packet itself arrived with
+ // protocol EthernetProtocolAll, the lookup above already
+ // fetched those endpoints, so don't add them twice.
+ if protocol != header.EthernetProtocolAll {
+ packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...)
+ }
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ ep.HandlePacket(n.id, local, protocol, pkt.Clone())
+ }
+
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
n.stack.stats.IP.PacketsReceived.Increment()
}
- if len(vv.First()) < netProto.MinimumPacketSize() {
+ if len(pkt.Data.First()) < netProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- src, dst := netProto.ParseAddresses(vv.First())
+ src, dst := netProto.ParseAddresses(pkt.Data.First())
if ref := n.getRef(protocol, dst); ref != nil {
- handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
+ handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt)
return
}
@@ -698,31 +832,34 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
if ok {
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
- ref.ep.HandlePacket(&r, vv)
+ ref.ep.HandlePacket(&r, pkt)
ref.decRef()
} else {
// n doesn't have a destination endpoint.
// Send the packet out of n.
- hdr := buffer.NewPrependableFromView(vv.First())
- vv.RemoveFirst()
+ pkt.Header = buffer.NewPrependableFromView(pkt.Data.First())
+ pkt.Data.RemoveFirst()
// TODO(b/128629022): use route.WritePacket.
- if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil {
+ if err := n.linkEP.WritePacket(&r, nil /* gso */, protocol, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + vv.Size()))
+ n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size()))
}
}
return
}
- n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ // If a packet socket handled the packet, don't treat it as invalid.
+ if len(packetEPs) == 0 {
+ n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ }
}
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -734,41 +871,41 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
+ n.stack.demux.deliverRawPacket(r, protocol, pkt)
- if len(vv.First()) < transProto.MinimumPacketSize() {
+ if len(pkt.Data.First()) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First())
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
+ if n.stack.demux.deliverPacket(r, protocol, pkt, id) {
return
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
- if state.defaultHandler(r, id, netHeader, vv) {
+ if state.defaultHandler(r, id, pkt) {
return
}
}
// We could not find an appropriate destination for this packet, so
// deliver it to the global handler.
- if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) {
+ if !transProto.HandleUnknownDestinationPacket(r, id, pkt) {
n.stack.stats.MalformedRcvdPackets.Increment()
}
}
// DeliverTransportControlPacket delivers control packets to the appropriate
// transport protocol endpoint.
-func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) {
state, ok := n.stack.transportProtocols[trans]
if !ok {
return
@@ -779,17 +916,17 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
// ICMPv4 only guarantees that 8 bytes of the transport protocol will
// be present in the payload. We know that the ports are within the
// first 8 bytes for all known transport protocols.
- if len(vv.First()) < 8 {
+ if len(pkt.Data.First()) < 8 {
return
}
- srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First())
if err != nil {
return
}
id := TransportEndpointID{srcPort, local, dstPort, remote}
- if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) {
+ if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) {
return
}
}
@@ -838,6 +975,26 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
return n.removePermanentAddressLocked(addr)
}
+// setNDPConfigs sets the NDP configurations for n.
+//
+// Note, invalid NDP configuration values in c will be replaced with their
+// default values.
+func (n *NIC) setNDPConfigs(c NDPConfigurations) {
+ c.validate()
+
+ n.mu.Lock()
+ n.ndp.configs = c
+ n.mu.Unlock()
+}
+
+// handleNDPRA handles an NDP Router Advertisement message that arrived on n.
+func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.ndp.handleRA(ip, ra)
+}
+
type networkEndpointKind int32
const (
@@ -857,7 +1014,7 @@ const (
// removing the permanent address from the NIC.
permanent
- // An expired permanent endoint is a permanent endoint that had its address
+ // An expired permanent endpoint is a permanent endpoint that had its address
// removed from the NIC, and it is waiting to be removed once no more routes
// hold a reference to it. This is achieved by decreasing its reference count
// by 1. If its address is re-added before the endpoint is removed, its type
@@ -873,8 +1030,50 @@ const (
temporary
)
+func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ eps, ok := n.packetEPs[netProto]
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
+ n.packetEPs[netProto] = append(eps, ep)
+
+ return nil
+}
+
+func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ eps, ok := n.packetEPs[netProto]
+ if !ok {
+ return
+ }
+
+ for i, epOther := range eps {
+ if epOther == ep {
+ n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
+ return
+ }
+ }
+}
+
+type networkEndpointConfigType int32
+
+const (
+ // A statically configured endpoint is an address that was added by
+ // some user-specified action (adding an explicit address, joining a
+ // multicast group).
+ static networkEndpointConfigType = iota
+
+ // A slaac configured endpoint is an IPv6 endpoint that was
+ // added by SLAAC as per RFC 4862 section 5.5.3.
+ slaac
+)
+
type referencedNetworkEndpoint struct {
- ilist.Entry
ep NetworkEndpoint
nic *NIC
protocol tcpip.NetworkProtocolNumber
@@ -889,6 +1088,10 @@ type referencedNetworkEndpoint struct {
// networkEndpointKind must only be accessed using {get,set}Kind().
kind networkEndpointKind
+
+ // configType is the method that was used to configure this endpoint.
+ // This must never change after the endpoint is added to a NIC.
+ configType networkEndpointConfigType
}
func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
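
Several hunks above swap the ilist-based primary endpoint list for a plain slice, so removal now uses Go's append-splice idiom (see removeEndpointLocked and unregisterPacketEndpoint). A small standalone sketch of that idiom, with the type name reused purely for illustration:

// removeRef deletes the first occurrence of target from refs using the
// append-splice idiom. append shifts the tail left in place, so the backing
// array is mutated and the returned slice is one element shorter; order is
// preserved.
func removeRef(refs []*referencedNetworkEndpoint, target *referencedNetworkEndpoint) []*referencedNetworkEndpoint {
	for i, r := range refs {
		if r == target {
			return append(refs[:i], refs[i+1:]...)
		}
	}
	return refs // target not present; return the slice unchanged
}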
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 9d6157f22..61fd46d66 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -60,24 +60,64 @@ const (
// TransportEndpoint is the interface that needs to be implemented by transport
// protocol (e.g., tcp, udp) endpoints that can handle packets.
type TransportEndpoint interface {
+ // UniqueID returns a unique ID for this transport endpoint.
+ UniqueID() uint64
+
// HandlePacket is called by the stack when new packets arrive to
- // this transport endpoint.
- HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
+ // this transport endpoint. It sets pkt.TransportHeader.
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer)
- // HandleControlPacket is called by the stack when new control (e.g.,
+ // HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
- HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView)
+ // HandleControlPacket takes ownership of pkt.
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
+
+ // Close puts the endpoint in a closed state and frees all resources
+ // associated with it. This cleanup may happen asynchronously. Wait can
+ // be used to block on this asynchronous cleanup.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the endpoint to stop.
+ //
+ // An endpoint can be requested to stop its worker goroutines by calling
+ // its Close method.
+ //
+ // Wait will not block if the endpoint hasn't started any goroutines
+ // yet, even if it might later.
+ Wait()
}
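
The Close and Wait additions above form a two-step shutdown contract: Close requests asynchronous cleanup, and Wait blocks until the endpoint's worker goroutines have halted. A minimal sketch of a caller honoring that contract:

// shutdown tears down a transport endpoint per the contract above: Close may
// return before cleanup completes, so Wait is used to block on it.
func shutdown(ep stack.TransportEndpoint) {
	ep.Close()
	ep.Wait()
}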
// RawTransportEndpoint is the interface that needs to be implemented by raw
// transport protocol endpoints. RawTransportEndpoints receive the entire
-// packet - including the link, network, and transport headers - as delivered
-// to netstack.
+// packet - including the network and transport headers - as delivered to
+// netstack.
type RawTransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint. The packet contains all data from the link
// layer up.
- HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView)
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(r *Route, pkt tcpip.PacketBuffer)
+}
+
+// PacketEndpoint is the interface that needs to be implemented by packet
+// transport protocol endpoints. These endpoints receive the link layer header
+// in addition to the rest of the packet (usually network and transport layer
+// headers and a payload).
+type PacketEndpoint interface {
+ // HandlePacket is called by the stack when new packets arrive that
+ // match the endpoint.
+ //
+ // Implementers should treat pkt as immutable and should copy it
+ // before modification.
+ //
+ // pkt.LinkHeader may have a length of 0, in which case the
+ // PacketEndpoint should construct its own ethernet header for
+ // applications.
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
}
// TransportProtocol is the interface that needs to be implemented by transport
@@ -107,7 +147,9 @@ type TransportProtocol interface {
//
// The return value indicates whether the packet was well-formed (for
// stats purposes only).
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
+ //
+ // HandleUnknownDestinationPacket takes ownership of pkt.
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -125,13 +167,21 @@ type TransportProtocol interface {
// the network layer.
type TransportDispatcher interface {
// DeliverTransportPacket delivers packets to the appropriate
- // transport protocol endpoint. It also returns the network layer
- // header for the enpoint to inspect or pass up the stack.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView)
+ // transport protocol endpoint.
+ //
+ // pkt.NetworkHeader must be set before calling DeliverTransportPacket.
+ //
+ // DeliverTransportPacket takes ownership of pkt.
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
- DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView)
+ //
+ // pkt.NetworkHeader must be set before calling
+ // DeliverTransportControlPacket.
+ //
+ // DeliverTransportControlPacket takes ownership of pkt.
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
}
// PacketLooping specifies where an outbound packet should be sent.
@@ -182,12 +232,17 @@ type NetworkEndpoint interface {
MaxHeaderLength() uint16
// WritePacket writes a packet to the given destination address and
- // protocol.
- WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error
+ // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have
+ // already been set.
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error
+
+ // WritePackets writes packets to the given destination address and
+ // protocol. pkts must not be zero length.
+ WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address.
- WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error
+ WriteHeaderIncludedPacket(r *Route, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
@@ -199,8 +254,10 @@ type NetworkEndpoint interface {
NICID() tcpip.NICID
// HandlePacket is called by the link layer when new packets arrive to
- // this network endpoint.
- HandlePacket(r *Route, vv buffer.VectorisedView)
+ // this network endpoint. It sets pkt.NetworkHeader.
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(r *Route, pkt tcpip.PacketBuffer)
// Close is called when the endpoint is removed from a stack.
Close()
@@ -225,7 +282,7 @@ type NetworkProtocol interface {
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
+ NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -242,9 +299,15 @@ type NetworkProtocol interface {
// packets to the appropriate network endpoint after it has been handled by
// the data link layer.
type NetworkDispatcher interface {
- // DeliverNetworkPacket finds the appropriate network protocol
- // endpoint and hands the packet over for further processing.
- DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+ // DeliverNetworkPacket finds the appropriate network protocol endpoint
+ // and hands the packet over for further processing.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverNetworkPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverNetworkPacket takes ownership of pkt.
+ DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -266,12 +329,18 @@ const (
CapabilitySaveRestore
CapabilityDisconnectOk
CapabilityLoopback
- CapabilityGSO
+ CapabilityHardwareGSO
+
+ // CapabilitySoftwareGSO indicates the link endpoint supports sending
+ // multiple packets using a single call (LinkEndpoint.WritePackets).
+ CapabilitySoftwareGSO
)
// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
// ethernet, loopback, raw) and used by network layer protocols to send packets
-// out through the implementer's data link endpoint.
+// out through the implementer's data link endpoint. When a link header exists,
+// the endpoint sets each tcpip.PacketBuffer's LinkHeader field before passing
+// it up the stack.
type LinkEndpoint interface {
// MTU is the maximum transmission unit for this endpoint. This is
// usually dictated by the backing physical network; when such a
@@ -293,13 +362,27 @@ type LinkEndpoint interface {
// link endpoint.
LinkAddress() tcpip.LinkAddress
- // WritePacket writes a packet with the given protocol through the given
- // route.
+ // WritePacket writes a packet with the given protocol through the
+ // given route. It sets pkt.LinkHeader if a link layer header exists.
+ // pkt.NetworkHeader and pkt.TransportHeader must have already been
+ // set.
//
// To participate in transparent bridging, a LinkEndpoint implementation
// should call eth.Encode with header.EthernetFields.SrcAddr set to
// r.LocalLinkAddress if it is provided.
- WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error
+
+ // WritePackets writes packets with the given protocol through the
+ // given route. pkts must not be zero length.
+ //
+ // Right now, WritePackets is used only when the software segmentation
+ // offload is enabled. If it is used for anything else, the syscall
+ // filters may need to be updated.
+ WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+
+ // WriteRawPacket writes a packet directly to the link. The packet
+ // should already have an ethernet header.
+ WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
@@ -324,13 +407,14 @@ type LinkEndpoint interface {
type InjectableLinkEndpoint interface {
LinkEndpoint
- // Inject injects an inbound packet.
- Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+ // InjectInbound injects an inbound packet.
+ InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
- // WriteRawPacket writes a fully formed outbound packet directly to the link.
+ // InjectOutbound writes a fully formed outbound packet directly to the
+ // link.
//
// dest is used by endpoints with multiple raw destinations.
- WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error
+ InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error
}
// A LinkAddressResolver is an extension to a NetworkProtocol that
@@ -359,10 +443,10 @@ type LinkAddressResolver interface {
type LinkAddressCache interface {
// CheckLocalAddress determines if the given local address exists, and if it
// does, returns the id of the NIC it's bound to. Returns 0 if the address
// does not exist.
- CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
+ CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
// AddLinkAddress adds a link address to the cache.
- AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
+ AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
// GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
// If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
@@ -373,17 +457,22 @@ type LinkAddressCache interface {
// If address resolution is required, ErrNoLinkAddress and a notification channel is
// returned for the top level caller to block. Channel is closed once address resolution
// is complete (success or not).
- GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
+ GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
// RemoveWaker removes a waker that has been added in GetLinkAddress().
- RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
+ RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
}
-// UnassociatedEndpointFactory produces endpoints for writing packets not
-// associated with a particular transport protocol. Such endpoints can be used
-// to write arbitrary packets that include the IP header.
-type UnassociatedEndpointFactory interface {
- NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+// RawFactory produces endpoints for writing various types of raw packets.
+type RawFactory interface {
+ // NewUnassociatedEndpoint produces endpoints for writing packets not
+ // associated with a particular transport protocol. Such endpoints can
+ // be used to write arbitrary packets that include the network header.
+ NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
+ // NewPacketEndpoint produces endpoints for reading and writing packets
+ // that include network and (when cooked is false) link layer headers.
+ NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
}
// GSOType is the type of GSO segments.
@@ -394,8 +483,14 @@ type GSOType int
// Types of gso segments.
const (
GSONone GSOType = iota
+
+ // Hardware GSO types:
GSOTCPv4
GSOTCPv6
+
+ // GSOSW is used for software GSO segments which have to be sent by
+ // endpoint.WritePackets.
+ GSOSW
)
// GSO contains generic segmentation offload properties.
@@ -423,3 +518,7 @@ type GSOEndpoint interface {
// GSOMaxSize returns the maximum GSO packet size.
GSOMaxSize() uint32
}
+
+// SoftwareGSOMaxSize is the maximum allowed size of a software GSO segment.
+// This isn't a hard limit, because it is never set into packet headers.
+const SoftwareGSOMaxSize = (1 << 16)
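
To make the new PacketEndpoint interface concrete, here is a minimal sketch of an implementation that only counts delivered packets; a real packet socket (the AF_PACKET-style endpoints this interface exists for) would queue pkt for readers instead. The type name is illustrative.

// countingPacketEndpoint implements stack.PacketEndpoint by counting
// deliveries. It needs "sync/atomic", since HandlePacket may be called from
// multiple NICs concurrently.
type countingPacketEndpoint struct {
	count uint64 // accessed atomically
}

func (e *countingPacketEndpoint) HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
	// HandlePacket takes ownership of pkt, and the NIC cloned it before
	// delivery, so simply dropping it here is safe.
	atomic.AddUint64(&e.count, 1)
}

Such an endpoint would be attached with Stack.RegisterPacketEndpoint (added in stack.go below), with nicID 0 to capture on every NIC.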
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index e72373964..34307ae07 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -17,7 +17,6 @@ package stack
import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -47,8 +46,8 @@ type Route struct {
// starts.
ref *referencedNetworkEndpoint
- // loop controls where WritePacket should send packets.
- loop PacketLooping
+ // Loop controls where WritePacket should send packets.
+ Loop PacketLooping
}
// makeRoute initializes a new route. It takes ownership of the provided
@@ -69,7 +68,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
LocalLinkAddress: localLinkAddr,
RemoteAddress: remoteAddr,
ref: ref,
- loop: loop,
+ Loop: loop,
}
}
@@ -154,34 +153,54 @@ func (r *Route) IsResolutionRequired() bool {
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams) *tcpip.Error {
+func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.loop)
+ err := r.ref.ep.WritePacket(r, gso, params, r.Loop, pkt)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + payload.Size()))
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size()))
}
return err
}
+// WritePackets writes the set of packets through the given route.
+func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) {
+ if !r.ref.isValidForOutgoing() {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ n, err := r.ref.ep.WritePackets(r, gso, pkts, params, r.Loop)
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n))
+ }
+ r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n))
+ payloadSize := 0
+ for i := 0; i < n; i++ {
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkts[i].Header.UsedLength()))
+ payloadSize += pkts[i].DataSize
+ }
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize))
+ return n, err
+}
+
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
+func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil {
+ if err := r.ref.ep.WriteHeaderIncludedPacket(r, r.Loop, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size()))
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
return nil
}
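
WritePackets above reports a partial-write count n alongside any error: n packets were written even when err is non-nil, and the stats accounting reflects that split. A minimal sketch of a caller honoring that contract (the logging via the standard "log" package is illustrative):

// flushPackets writes a batch through r and accounts for partial writes, as
// permitted by the WritePackets contract above.
func flushPackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) {
	n, err := r.WritePackets(gso, pkts, params)
	if err != nil {
		log.Printf("wrote %d of %d packets: %v", n, len(pkts), err)
	}
}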
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index a199bc1cc..0e88643a4 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -22,6 +22,7 @@ package stack
import (
"encoding/binary"
"sync"
+ "sync/atomic"
"time"
"golang.org/x/time/rate"
@@ -50,7 +51,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
+ defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -344,6 +345,13 @@ type ResumableEndpoint interface {
Resume(*Stack)
}
+// uniqueIDGenerator is a default unique ID generator.
+type uniqueIDGenerator uint64
+
+func (u *uniqueIDGenerator) UniqueID() uint64 {
+ return atomic.AddUint64((*uint64)(u), 1)
+}
+
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -351,10 +359,9 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
- // unassociatedFactory creates unassociated endpoints. If nil, raw
- // endpoints are disabled. It is set during Stack creation and is
- // immutable.
- unassociatedFactory UnassociatedEndpointFactory
+ // rawFactory creates raw endpoints. If nil, raw endpoints are
+ // disabled. It is set during Stack creation and is immutable.
+ rawFactory RawFactory
demux *transportDemuxer
@@ -362,9 +369,10 @@ type Stack struct {
linkAddrCache *linkAddrCache
- mu sync.RWMutex
- nics map[tcpip.NICID]*NIC
- forwarding bool
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
+ forwarding bool
+ cleanupEndpoints map[TransportEndpoint]struct{}
// route is the route table passed in by the user via SetRouteTable(),
// it is used by FindRoute() to build a route for a specific
@@ -394,14 +402,31 @@ type Stack struct {
// by the stack.
icmpRateLimiter *ICMPRateLimiter
- // portSeed is a one-time random value initialized at stack startup
+ // seed is a one-time random value initialized at stack startup
// and is used to seed the TCP port picking on active connections
//
// TODO(gvisor.dev/issue/940): S/R this field.
- portSeed uint32
+ seed uint32
- // ndpConfigs is the NDP configurations used by interfaces.
+ // ndpConfigs is the default NDP configurations used by interfaces.
ndpConfigs NDPConfigurations
+
+ // autoGenIPv6LinkLocal determines whether or not the stack will attempt
+ // to auto-generate an IPv6 link-local address for newly enabled NICs.
+ // See the AutoGenIPv6LinkLocal field of Options for more details.
+ autoGenIPv6LinkLocal bool
+
+ // ndpDisp is the NDP event dispatcher that is used to send the netstack
+ // integrator NDP related events.
+ ndpDisp NDPDispatcher
+
+ // uniqueIDGenerator is a generator of unique identifiers.
+ uniqueIDGenerator UniqueID
+}
+
+// UniqueID is an abstract generator of unique identifiers.
+type UniqueID interface {
+ UniqueID() uint64
}
// Options contains optional Stack configuration.
@@ -425,16 +450,35 @@ type Options struct {
// stack (false).
HandleLocal bool
- // UnassociatedFactory produces unassociated endpoints raw endpoints.
- // Raw endpoints are enabled only if this is non-nil.
- UnassociatedFactory UnassociatedEndpointFactory
+ // UniqueID is an optional generator of unique identifiers.
+ UniqueID UniqueID
- // NDPConfigs is the NDP configurations used by interfaces.
+ // NDPConfigs is the default NDP configurations used by interfaces.
//
// By default, NDPConfigs will have a zero value for its
// DupAddrDetectTransmits field, implying that DAD will not be performed
// before assigning an address to a NIC.
NDPConfigs NDPConfigurations
+
+ // AutoGenIPv6LinkLocal determines whether or not the stack will attempt
+ // to auto-generate an IPv6 link-local address for newly enabled NICs.
+ // Note, setting this to true does not mean that a link-local address
+ // will be assigned right away, or at all. If Duplicate Address
+ // Detection is enabled, an address will only be assigned if it
+ // successfully resolves. If it fails, no further attempt will be made
+ // to auto-generate an IPv6 link-local address.
+ //
+ // The generated link-local address will follow RFC 4291 Appendix A
+ // guidelines.
+ AutoGenIPv6LinkLocal bool
+
+ // NDPDisp is the NDP event dispatcher that an integrator can provide to
+ // receive NDP related events.
+ NDPDisp NDPDispatcher
+
+ // RawFactory produces raw endpoints. Raw endpoints are enabled only if
+ // this is non-nil.
+ RawFactory RawFactory
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -481,22 +525,30 @@ func New(opts Options) *Stack {
clock = &tcpip.StdClock{}
}
+ if opts.UniqueID == nil {
+ opts.UniqueID = new(uniqueIDGenerator)
+ }
+
// Make sure opts.NDPConfigs contains valid values only.
opts.NDPConfigs.validate()
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
- nics: make(map[tcpip.NICID]*NIC),
- linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- icmpRateLimiter: NewICMPRateLimiter(),
- portSeed: generateRandUint32(),
- ndpConfigs: opts.NDPConfigs,
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+ nics: make(map[tcpip.NICID]*NIC),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: generateRandUint32(),
+ ndpConfigs: opts.NDPConfigs,
+ autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ uniqueIDGenerator: opts.UniqueID,
+ ndpDisp: opts.NDPDisp,
}
// Add specified network protocols.
@@ -514,8 +566,8 @@ func New(opts Options) *Stack {
}
}
- // Add the factory for unassociated endpoints, if present.
- s.unassociatedFactory = opts.UnassociatedFactory
+ // Add the factory for raw endpoints, if present.
+ s.rawFactory = opts.RawFactory
// Create the global transport demuxer.
s.demux = newTransportDemuxer(s)
@@ -523,6 +575,11 @@ func New(opts Options) *Stack {
return s
}
+// UniqueID returns a unique identifier.
+func (s *Stack) UniqueID() uint64 {
+ return s.uniqueIDGenerator.UniqueID()
+}
+
// SetNetworkProtocolOption allows configuring individual protocol level
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
@@ -584,7 +641,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -650,12 +707,12 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
// protocol. Raw endpoints receive all traffic for a given protocol regardless
// of address.
func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if s.unassociatedFactory == nil {
+ if s.rawFactory == nil {
return nil, tcpip.ErrNotPermitted
}
if !associated {
- return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue)
+ return s.rawFactory.NewUnassociatedEndpoint(s, network, transport, waiterQueue)
}
t, ok := s.transportProtocols[transport]
@@ -666,6 +723,16 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
return t.proto.NewRawEndpoint(s, network, waiterQueue)
}
+// NewPacketEndpoint creates a new packet endpoint listening for the given
+// netProto.
+func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if s.rawFactory == nil {
+ return nil, tcpip.ErrNotPermitted
+ }
+
+ return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue)
+}
+
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error {
@@ -988,13 +1055,13 @@ func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool
// CheckLocalAddress determines if the given local address exists, and if it
// does, returns the id of the NIC it's bound to. Returns 0 if the address
// does not exist.
-func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
+func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
s.mu.RLock()
defer s.mu.RUnlock()
// If a NIC is specified, we try to find the address there only.
- if nicid != 0 {
- nic := s.nics[nicid]
+ if nicID != 0 {
+ nic := s.nics[nicID]
if nic == nil {
return 0
}
@@ -1053,35 +1120,35 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
}
// AddLinkAddress adds a link address to the stack link cache.
-func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
s.linkAddrCache.add(fullAddr, linkAddr)
// TODO: provide a way for a transport endpoint to receive a signal
// that AddLinkAddress for a particular address has been called.
}
// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
-func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
s.mu.RLock()
- nic := s.nics[nicid]
+ nic := s.nics[nicID]
if nic == nil {
s.mu.RUnlock()
return "", nil, tcpip.ErrUnknownNICID
}
s.mu.RUnlock()
- fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
}
// RemoveWaker implements LinkAddressCache.RemoveWaker.
-func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
+func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
s.mu.RLock()
defer s.mu.RUnlock()
- if nic := s.nics[nicid]; nic == nil {
- fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ if nic := s.nics[nicID]; nic == nil {
+ fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
s.linkAddrCache.removeWaker(fullAddr, waker)
}
}
@@ -1100,6 +1167,31 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
}
+// StartTransportEndpointCleanup removes the endpoint with the given id from
+// the stack transport dispatcher. It also transitions it to the cleanup stage.
+func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.cleanupEndpoints[ep] = struct{}{}
+
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
+}
+
+// CompleteTransportEndpointCleanup removes the endpoint from the cleanup
+// stage.
+func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
+ s.mu.Lock()
+ delete(s.cleanupEndpoints, ep)
+ s.mu.Unlock()
+}
+
+// FindTransportEndpoint finds an endpoint that most closely matches the provided
+// id. If no endpoint is found it returns nil.
+func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+ return s.demux.findTransportEndpoint(netProto, transProto, id, r)
+}
+
// RegisterRawTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided transport
// protocol will be delivered to the given endpoint.
@@ -1121,6 +1213,69 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) {
s.mu.Unlock()
}
+// RegisteredEndpoints returns all endpoints which are currently registered.
+func (s *Stack) RegisteredEndpoints() []TransportEndpoint {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ var es []TransportEndpoint
+ for _, e := range s.demux.protocol {
+ es = append(es, e.transportEndpoints()...)
+ }
+ return es
+}
+
+// CleanupEndpoints returns endpoints currently in the cleanup state.
+func (s *Stack) CleanupEndpoints() []TransportEndpoint {
+ s.mu.Lock()
+ es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints))
+ for e := range s.cleanupEndpoints {
+ es = append(es, e)
+ }
+ s.mu.Unlock()
+ return es
+}
+
+// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
+// for restoring a stack after a save.
+func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
+ s.mu.Lock()
+ for _, e := range es {
+ s.cleanupEndpoints[e] = struct{}{}
+ }
+ s.mu.Unlock()
+}
+
+// Close closes all currently registered transport endpoints.
+//
+// Endpoints created or modified during this call may not get closed.
+func (s *Stack) Close() {
+ for _, e := range s.RegisteredEndpoints() {
+ e.Close()
+ }
+}
+
+// Wait waits for all transport and link endpoints to halt their worker
+// goroutines.
+//
+// Endpoints created or modified during this call may not get waited on.
+//
+// Note that link endpoints must be stopped via an implementation specific
+// mechanism.
+func (s *Stack) Wait() {
+ for _, e := range s.RegisteredEndpoints() {
+ e.Wait()
+ }
+ for _, e := range s.CleanupEndpoints() {
+ e.Wait()
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, n := range s.nics {
+ n.linkEP.Wait()
+ }
+}
+
// Resume restarts the stack after a restore. This must be called after the
// entire system has been restored.
func (s *Stack) Resume() {
@@ -1135,6 +1290,109 @@ func (s *Stack) Resume() {
}
}
+// RegisterPacketEndpoint registers ep with the stack, causing it to receive
+// all traffic of the specified netProto on the given NIC. If nicID is 0, it
+// receives traffic from every NIC.
+func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // If no NIC is specified, capture on all devices.
+ if nicID == 0 {
+ // Register with each NIC.
+ for _, nic := range s.nics {
+ if err := nic.registerPacketEndpoint(netProto, ep); err != nil {
+ s.unregisterPacketEndpointLocked(0, netProto, ep)
+ return err
+ }
+ }
+ return nil
+ }
+
+ // Capture on a specific device.
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+ if err := nic.registerPacketEndpoint(netProto, ep); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnregisterPacketEndpoint unregisters ep for packets of the specified
+// netProto from the specified NIC. If nicID is 0, ep is unregistered from all
+// NICs.
+func (s *Stack) UnregisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.unregisterPacketEndpointLocked(nicID, netProto, ep)
+}
+
+func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ // If no NIC is specified, unregister on all devices.
+ if nicID == 0 {
+ // Unregister with each NIC.
+ for _, nic := range s.nics {
+ nic.unregisterPacketEndpoint(netProto, ep)
+ }
+ return
+ }
+
+ // Unregister on a specific device.
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return
+ }
+ nic.unregisterPacketEndpoint(netProto, ep)
+}
+
+// WritePacket writes data directly to the specified NIC. It adds an ethernet
+// header based on the arguments.
+func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
+ s.mu.Lock()
+ nic, ok := s.nics[nicID]
+ s.mu.Unlock()
+ if !ok {
+ return tcpip.ErrUnknownDevice
+ }
+
+ // Add our own fake ethernet header.
+ ethFields := header.EthernetFields{
+ SrcAddr: nic.linkEP.LinkAddress(),
+ DstAddr: dst,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ vv := buffer.View(fakeHeader).ToVectorisedView()
+ vv.Append(payload)
+
+ if err := nic.linkEP.WriteRawPacket(vv); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// WriteRawPacket writes data directly to the specified NIC without adding any
+// headers.
+func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
+ s.mu.Lock()
+ nic, ok := s.nics[nicID]
+ s.mu.Unlock()
+ if !ok {
+ return tcpip.ErrUnknownDevice
+ }
+
+ if err := nic.linkEP.WriteRawPacket(payload); err != nil {
+ return err
+ }
+
+ return nil
+}
+
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
@@ -1286,12 +1544,47 @@ func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tc
return nic.dupTentativeAddrDetected(addr)
}
-// PortSeed returns a 32 bit value that can be used as a seed value for port
-// picking.
+// SetNDPConfigurations sets the per-interface NDP configurations on the NIC
+// with ID id to c.
+//
+// Note: any invalid NDP configuration values in c are replaced with their
+// default values.
+func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.setNDPConfigs(c)
+
+ return nil
+}
+
+// HandleNDPRA provides the NIC with ID id a validated NDP Router
+// Advertisement message to handle.
+func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.handleNDPRA(ip, ra)
+
+ return nil
+}
+
+// Seed returns a 32-bit value that can be used as a seed for port picking,
+// ISN generation, etc.
//
// NOTE: The seed is generated once during stack initialization only.
-func (s *Stack) PortSeed() uint32 {
- return s.portSeed
+func (s *Stack) Seed() uint32 {
+ return s.seed
}
func generateRandUint32() uint32 {
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 10fd1065f..8fc034ca1 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -24,11 +24,14 @@ import (
"sort"
"strings"
"testing"
+ "time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -55,7 +58,7 @@ const (
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fakeNetworkEndpoint struct {
- nicid tcpip.NICID
+ nicID tcpip.NICID
id stack.NetworkEndpointID
prefixLen int
proto *fakeNetworkProtocol
@@ -68,7 +71,7 @@ func (f *fakeNetworkEndpoint) MTU() uint32 {
}
func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
- return f.nicid
+ return f.nicID
}
func (f *fakeNetworkEndpoint) PrefixLen() int {
@@ -83,28 +86,28 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
return &f.id
}
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
+func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
// Consume the network header.
- b := vv.First()
- vv.TrimFront(fakeNetHeaderLen)
+ b := pkt.Data.First()
+ pkt.Data.TrimFront(fakeNetHeaderLen)
// Handle control packets.
if b[2] == uint8(fakeControlProtocol) {
- nb := vv.First()
+ nb := pkt.Data.First()
if len(nb) < fakeNetHeaderLen {
return
}
- vv.TrimFront(fakeNetHeaderLen)
- f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv)
+ pkt.Data.TrimFront(fakeNetHeaderLen)
+ f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt)
return
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -119,32 +122,38 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.ep.Capabilities()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
// Add the protocol's header to the packet and send it to the link
// endpoint.
- b := hdr.Prepend(fakeNetHeaderLen)
+ b := pkt.Header.Prepend(fakeNetHeaderLen)
b[0] = r.RemoteAddress[0]
b[1] = f.id.LocalAddress[0]
b[2] = byte(params.Protocol)
if loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(payload.Views()))
- views[0] = hdr.View()
- views = append(views, payload.Views()...)
- vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
- f.HandlePacket(r, vv)
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
+ f.HandlePacket(r, tcpip.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
+ })
}
if loop&stack.PacketOut == 0 {
return nil
}
- return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber)
+ return f.ep.WritePacket(r, gso, fakeNetNumber, pkt)
+}
+
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+ panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -189,9 +198,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &fakeNetworkEndpoint{
- nicid: nicid,
+ nicID: nicID,
id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
prefixLen: addrWithPrefix.PrefixLen,
proto: f,
@@ -251,7 +260,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet with wrong address is not delivered.
buf[0] = 3
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
@@ -261,7 +272,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to first endpoint.
buf[0] = 1
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -271,7 +284,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to second endpoint.
buf[0] = 2
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -280,7 +295,9 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- ep.Inject(fakeNetNumber-1, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -290,7 +307,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -310,7 +329,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS})
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload.ToVectorisedView(),
+ })
}
func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) {
@@ -365,7 +387,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b
func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
t.Helper()
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if got := fakeNet.PacketCount(localAddrByte); got != want {
t.Errorf("receive packet count: got = %d, want %d", got, want)
}
@@ -660,11 +684,11 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
}
}
-func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) {
+func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) {
t.Helper()
- info, ok := s.NICInfo()[nicid]
+ info, ok := s.NICInfo()[nicID]
if !ok {
- t.Fatalf("NICInfo() failed to find nicid=%d", nicid)
+ t.Fatalf("NICInfo() failed to find nicID=%d", nicID)
}
if len(addr) == 0 {
// No address given, verify that there is no address assigned to the NIC.
@@ -697,7 +721,7 @@ func TestEndpointExpiration(t *testing.T) {
localAddrByte byte = 0x01
remoteAddr tcpip.Address = "\x03"
noAddr tcpip.Address = ""
- nicid tcpip.NICID = 1
+ nicID tcpip.NICID = 1
)
localAddr := tcpip.Address([]byte{localAddrByte})
@@ -709,7 +733,7 @@ func TestEndpointExpiration(t *testing.T) {
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -726,13 +750,13 @@ func TestEndpointExpiration(t *testing.T) {
buf[0] = localAddrByte
if promiscuous {
- if err := s.SetPromiscuousMode(nicid, true); err != nil {
+ if err := s.SetPromiscuousMode(nicID, true); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
}
if spoofing {
- if err := s.SetSpoofing(nicid, true); err != nil {
+ if err := s.SetSpoofing(nicID, true); err != nil {
t.Fatal("SetSpoofing failed:", err)
}
}
@@ -740,7 +764,7 @@ func TestEndpointExpiration(t *testing.T) {
// 1. No Address yet, send should only work for spoofing, receive for
// promiscuous mode.
//-----------------------
- verifyAddress(t, s, nicid, noAddr)
+ verifyAddress(t, s, nicID, noAddr)
if promiscuous {
testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
@@ -755,20 +779,20 @@ func TestEndpointExpiration(t *testing.T) {
// 2. Add Address, everything should work.
//-----------------------
- if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
- verifyAddress(t, s, nicid, localAddr)
+ verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
// 3. Remove the address, send should only work for spoofing, receive
// for promiscuous mode.
//-----------------------
- if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- verifyAddress(t, s, nicid, noAddr)
+ verifyAddress(t, s, nicID, noAddr)
if promiscuous {
testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
@@ -783,10 +807,10 @@ func TestEndpointExpiration(t *testing.T) {
// 4. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
- verifyAddress(t, s, nicid, localAddr)
+ verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -804,10 +828,10 @@ func TestEndpointExpiration(t *testing.T) {
// 6. Remove the address. Send should only work for spoofing, receive
// for promiscuous mode.
//-----------------------
- if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- verifyAddress(t, s, nicid, noAddr)
+ verifyAddress(t, s, nicID, noAddr)
if promiscuous {
testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
@@ -823,10 +847,10 @@ func TestEndpointExpiration(t *testing.T) {
// 7. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
- verifyAddress(t, s, nicid, localAddr)
+ verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
testSend(t, r, ep, nil)
@@ -834,17 +858,17 @@ func TestEndpointExpiration(t *testing.T) {
// 8. Remove the route, sendTo/recv should still work.
//-----------------------
r.Release()
- verifyAddress(t, s, nicid, localAddr)
+ verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
// 9. Remove the address. Send should only work for spoofing, receive
// for promiscuous mode.
//-----------------------
- if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- verifyAddress(t, s, nicid, noAddr)
+ verifyAddress(t, s, nicID, noAddr)
if promiscuous {
testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
@@ -1637,12 +1661,12 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
}
func TestAddAddress(t *testing.T) {
- const nicid = 1
+ const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1650,7 +1674,7 @@ func TestAddAddress(t *testing.T) {
expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2)
for _, addrLen := range []int{4, 16} {
address := addrGen.next(addrLen)
- if err := s.AddAddress(nicid, fakeNetNumber, address); err != nil {
+ if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil {
t.Fatalf("AddAddress(address=%s) failed: %s", address, err)
}
expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
@@ -1659,17 +1683,17 @@ func TestAddAddress(t *testing.T) {
})
}
- gotAddresses := s.AllAddresses()[nicid]
+ gotAddresses := s.AllAddresses()[nicID]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddress(t *testing.T) {
- const nicid = 1
+ const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1686,24 +1710,24 @@ func TestAddProtocolAddress(t *testing.T) {
PrefixLen: prefixLen,
},
}
- if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil {
+ if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err)
}
expectedAddresses = append(expectedAddresses, protocolAddress)
}
}
- gotAddresses := s.AllAddresses()[nicid]
+ gotAddresses := s.AllAddresses()[nicID]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddAddressWithOptions(t *testing.T) {
- const nicid = 1
+ const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1714,7 +1738,7 @@ func TestAddAddressWithOptions(t *testing.T) {
for _, addrLen := range addrLenRange {
for _, behavior := range behaviorRange {
address := addrGen.next(addrLen)
- if err := s.AddAddressWithOptions(nicid, fakeNetNumber, address, behavior); err != nil {
+ if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err)
}
expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
@@ -1724,17 +1748,17 @@ func TestAddAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.AllAddresses()[nicid]
+ gotAddresses := s.AllAddresses()[nicID]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddressWithOptions(t *testing.T) {
- const nicid = 1
+ const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1753,7 +1777,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
PrefixLen: prefixLen,
},
}
- if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil {
+ if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err)
}
expectedAddresses = append(expectedAddresses, protocolAddress)
@@ -1761,7 +1785,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.AllAddresses()[nicid]
+ gotAddresses := s.AllAddresses()[nicID]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
@@ -1787,7 +1811,9 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
}
@@ -1847,7 +1873,9 @@ func TestNICForwarding(t *testing.T) {
// Send a packet to address 3.
buf := buffer.NewView(30)
buf[0] = 3
- ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
select {
case <-ep2.C:
@@ -1864,3 +1892,297 @@ func TestNICForwarding(t *testing.T) {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
+
+// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses
+// (or the lack thereof when disabled, which is the default). Note that DAD
+// is disabled in these tests.
+func TestNICAutoGenAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ autoGen bool
+ linkAddr tcpip.LinkAddress
+ shouldGen bool
+ }{
+ {
+ "Disabled",
+ false,
+ linkAddr1,
+ false,
+ },
+ {
+ "Enabled",
+ true,
+ linkAddr1,
+ true,
+ },
+ {
+ "Nil MAC",
+ true,
+ tcpip.LinkAddress([]byte(nil)),
+ false,
+ },
+ {
+ "Empty MAC",
+ true,
+ tcpip.LinkAddress(""),
+ false,
+ },
+ {
+ "Invalid MAC",
+ true,
+ tcpip.LinkAddress("\x01\x02\x03"),
+ false,
+ },
+ {
+ "Multicast MAC",
+ true,
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ false,
+ },
+ {
+ "Unspecified MAC",
+ true,
+ tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"),
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ }
+
+ if test.autoGen {
+ // Only set opts.AutoGenIPv6LinkLocal when
+ // test.autoGen is true because
+ // opts.AutoGenIPv6LinkLocal should be false by
+ // default.
+ opts.AutoGenIPv6LinkLocal = true
+ }
+
+ e := channel.New(10, 1280, test.linkAddr)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+
+ if test.shouldGen {
+ // Should have auto-generated an address and
+ // resolved immediately (DAD is disabled).
+ if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ }
+ } else {
+ // Should not have auto-generated an address.
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+ }
+ })
+ }
+}
+
+// TestNICAutoGenAddrDoesDAD tests that an auto-generated IPv6 link-local
+// address is only assigned after the DAD process resolves.
+func TestNICAutoGenAddrDoesDAD(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ }
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Address should not be considered bound to the
+ // NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ linkLocalAddr := header.LinkLocalAddr(linkAddr1)
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // We should get a resolution event after 1s (default time to
+ // resolve as per default NDP configurations). Waiting for that
+ // resolution time + an extra 1s without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != linkLocalAddr {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ }
+}
+
+// TestNewPEBOnPromotionToPermanent tests that a new PrimaryEndpointBehavior
+// value (peb) is respected when an address's kind gets "promoted" to
+// permanent from permanentExpired.
+func TestNewPEBOnPromotionToPermanent(t *testing.T) {
+ pebs := []stack.PrimaryEndpointBehavior{
+ stack.NeverPrimaryEndpoint,
+ stack.CanBePrimaryEndpoint,
+ stack.FirstPrimaryEndpoint,
+ }
+
+ for _, pi := range pebs {
+ for _, ps := range pebs {
+ t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ // Add a permanent address with initial
+ // PrimaryEndpointBehavior (peb), pi. If pi is
+ // NeverPrimaryEndpoint, the address should not
+ // be returned by a call to GetMainNICAddress;
+ // else, it should.
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+ }
+ addr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("s.GetMainNICAddress failed:", err)
+ }
+ if pi == stack.NeverPrimaryEndpoint {
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
+ }
+ } else if addr.Address != "\x01" {
+ t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatalf("NewSubnet failed:", err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ // Take a route through the address so its ref
+ // count gets incremented and does not actually
+ // get deleted when RemoveAddress is called
+ // below. This is because we want to test that a
+ // new peb is respected when an address gets
+ // "promoted" to permanent from a
+ // permanentExpired kind.
+ r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(1, "\x01"); err != nil {
+ t.Fatalf("RemoveAddress failed:", err)
+ }
+
+ //
+ // At this point, the address should still be
+ // known by the NIC, but have its
+ // kind = permanentExpired.
+ //
+
+ // Add some other address with peb set to
+ // FirstPrimaryEndpoint.
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+ }
+
+ // Add back the address we removed earlier and
+ // make sure the new peb was respected.
+ // (The address should just be promoted now).
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+ }
+ var primaryAddrs []tcpip.Address
+ for _, pa := range s.NICInfo()[1].ProtocolAddresses {
+ primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address)
+ }
+ var expectedList []tcpip.Address
+ switch ps {
+ case stack.FirstPrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x01",
+ "\x03",
+ }
+ case stack.CanBePrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x03",
+ "\x01",
+ }
+ case stack.NeverPrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x03",
+ }
+ }
+ if !cmp.Equal(primaryAddrs, expectedList) {
+ t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList)
+ }
+
+ // Once we remove the other address, if the new
+ // peb, ps, was NeverPrimaryEndpoint, no address
+ // should be returned by a call to
+ // GetMainNICAddress; else, our original address
+ // should be returned.
+ if err := s.RemoveAddress(1, "\x03"); err != nil {
+ t.Fatalf("RemoveAddress failed:", err)
+ }
+ addr, err = s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("s.GetMainNICAddress failed:", err)
+ }
+ if ps == stack.NeverPrimaryEndpoint {
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
+ }
+ } else {
+ if addr.Address != "\x01" {
+ t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ }
+ }
+ })
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 92267ce4d..67c21be42 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -17,10 +17,10 @@ package stack
import (
"fmt"
"math/rand"
+ "sort"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -41,6 +41,31 @@ type transportEndpoints struct {
rawEndpoints []RawTransportEndpoint
}
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ epsByNic, ok := eps.endpoints[id]
+ if !ok {
+ return
+ }
+ if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
+ return
+ }
+ delete(eps.endpoints, id)
+}
+
+func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
+ eps.mu.RLock()
+ defer eps.mu.RUnlock()
+ es := make([]TransportEndpoint, 0, len(eps.endpoints))
+ for _, e := range eps.endpoints {
+ es = append(es, e.transportEndpoints()...)
+ }
+ return es
+}
+
type endpointsByNic struct {
mu sync.RWMutex
endpoints map[tcpip.NICID]*multiPortEndpoint
@@ -48,9 +73,19 @@ type endpointsByNic struct {
seed uint32
}
+func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint {
+ epsByNic.mu.RLock()
+ defer epsByNic.mu.RUnlock()
+ var eps []TransportEndpoint
+ for _, ep := range epsByNic.endpoints {
+ eps = append(eps, ep.transportEndpoints()...)
+ }
+ return eps
+}
+
// HandlePacket is called by the stack when new packets arrive at this transport
// endpoint.
-func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
epsByNic.mu.RLock()
mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
@@ -64,18 +99,17 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, v
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
if isMulticastOrBroadcast(id.LocalAddress) {
- mpep.handlePacketAll(r, id, vv)
+ mpep.handlePacketAll(r, id, pkt)
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
return
}
-
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv)
+ selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt)
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) {
epsByNic.mu.RLock()
defer epsByNic.mu.RUnlock()
@@ -91,7 +125,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint
// broadcast like we are doing with handlePacket above?
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv)
+ selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt)
}
// registerEndpoint returns true if it succeeds. It fails and returns
@@ -127,21 +161,6 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T
return len(epsByNic.endpoints) == 0
}
-// unregisterEndpoint unregisters the endpoint with the given id such that it
-// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
- eps.mu.Lock()
- defer eps.mu.Unlock()
- epsByNic, ok := eps.endpoints[id]
- if !ok {
- return
- }
- if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
- return
- }
- delete(eps.endpoints, id)
-}
-
// transportDemuxer demultiplexes packets targeted at a transport endpoint
// (i.e., after they've been parsed by the network layer). It does two levels
// of demultiplexing: first based on the network and transport protocols, then
@@ -183,14 +202,27 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpointsArr always has at least one
// element.
+//
+// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save
+// this to ensure that the underlying endpoints get saved/restored, but do not
+// use the restored copy.
+//
+// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex
+ mu sync.RWMutex `state:"nosave"`
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
// reuse indicates if more than one endpoint is allowed.
reuse bool
}
+func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ ep.mu.RUnlock()
+ return eps
+}
+
// reciprocalScale scales a value into range [0, n).
//
// This is similar to val % n, but faster.
@@ -224,22 +256,40 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpointsArr[idx]
}
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
ep.mu.RLock()
for i, endpoint := range ep.endpointsArr {
- // HandlePacket modifies vv, so each endpoint needs its own copy except for
- // the final one.
+ // HandlePacket takes ownership of pkt, so each endpoint needs
+ // its own copy except for the final one.
if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, vv)
+ endpoint.HandlePacket(r, id, pkt)
break
}
- vvCopy := buffer.NewView(vv.Size())
- copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+ endpoint.HandlePacket(r, id, pkt.Clone())
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
+// Close implements stack.TransportEndpoint.Close.
+func (ep *multiPortEndpoint) Close() {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ ep.mu.RUnlock()
+ for _, e := range eps {
+ e.Close()
+ }
+}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (ep *multiPortEndpoint) Wait() {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ ep.mu.RUnlock()
+ for _, e := range eps {
+ e.Wait()
+ }
+}
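Close and Wait above both snapshot endpointsArr under the read lock and
release it before calling into the endpoints, so a re-entrant Close or Wait
cannot deadlock on the demuxer. The pattern in isolation (a sketch; mu and
shared are stand-ins):

    // Copy the shared slice while holding the lock, then operate on the
    // copy unlocked; callees may re-acquire mu safely.
    mu.RLock()
    snapshot := append([]TransportEndpoint(nil), shared...)
    mu.RUnlock()
    for _, e := range snapshot {
        e.Wait()
    }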
+
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
@@ -257,6 +307,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo
// endpointsMap. This will allow us to remove endpoint from the array fast.
ep.endpointsMap[t] = len(ep.endpointsArr)
ep.endpointsArr = append(ep.endpointsArr, t)
+
+ // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints
+ // can be restored in the same order.
+ sort.Slice(ep.endpointsArr, func(i, j int) bool {
+ return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID()
+ })
+ for i, e := range ep.endpointsArr {
+ ep.endpointsMap[e] = i
+ }
return nil
}
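Sorting endpointsArr invalidates the indices previously recorded in
endpointsMap, which is why the loop above rebuilds the map after the sort.
A sketch of the invariant being maintained (illustrative check only):

    // After any reordering, endpointsMap[e] must equal e's index in
    // endpointsArr so unregistration can locate endpoints in O(1).
    for i, e := range ep.endpointsArr {
        if ep.endpointsMap[e] != i {
            panic("endpointsMap out of sync with endpointsArr")
        }
    }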
@@ -330,9 +389,9 @@ var loopbackSubnet = func() tcpip.Subnet {
}()
// deliverPacket attempts to find one or more matching transport endpoints, and
-// then, if matches are found, delivers the packet to them. Returns true if it
-// found one or more endpoints, false otherwise.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool {
+// then, if matches are found, delivers the packet to them. Returns true if
+// the packet no longer needs to be handled.
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -341,13 +400,38 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
eps.mu.RLock()
// Determine which transport endpoint or endpoints to deliver this packet to.
- // If the packet is a broadcast or multicast, then find all matching
- // transport endpoints.
+ // If the packet is a UDP broadcast or multicast, then find all matching
+ // transport endpoints. If the packet is a TCP packet with a non-unicast
+ // source or destination address, then do nothing further and instruct
+ // the caller to do the same.
var destEps []*endpointsByNic
- if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
- destEps = d.findAllEndpointsLocked(eps, vv, id)
- } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
- destEps = append(destEps, ep)
+ switch protocol {
+ case header.UDPProtocolNumber:
+ if isMulticastOrBroadcast(id.LocalAddress) {
+ destEps = d.findAllEndpointsLocked(eps, id)
+ break
+ }
+
+ if ep := d.findEndpointLocked(eps, id); ep != nil {
+ destEps = append(destEps, ep)
+ }
+
+ case header.TCPProtocolNumber:
+ if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) {
+ // TCP can only be used to communicate between a single
+ // source and a single destination; the addresses must
+ // be unicast.
+ eps.mu.RUnlock()
+ r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ return true
+ }
+
+ fallthrough
+
+ default:
+ if ep := d.findEndpointLocked(eps, id); ep != nil {
+ destEps = append(destEps, ep)
+ }
}
eps.mu.RUnlock()
@@ -361,17 +445,19 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
return false
}
- // Deliver the packet.
- for _, ep := range destEps {
- ep.handlePacket(r, id, vv)
+ // HandlePacket takes ownership of pkt, so each endpoint needs its own
+ // copy except for the final one.
+ for _, ep := range destEps[:len(destEps)-1] {
+ ep.handlePacket(r, id, pkt.Clone())
}
+ destEps[len(destEps)-1].handlePacket(r, id, pkt)
return true
}
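Clone-per-receiver matters because HandlePacket may trim or consume the
buffer; the last destination takes ownership of the original, which makes the
common single-endpoint case copy-free. A sketch of the Clone semantics
(assumes pkt is a tcpip.PacketBuffer with buffered data):

    pkt2 := pkt.Clone()       // independent VectorisedView, shared bytes
    pkt.Data.TrimFront(1)     // consumes from the original view only
    _ = pkt2.Data.Size()      // pkt2 still sees the untrimmed size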
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool {
+func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -385,7 +471,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
for _, rawEP := range eps.rawEndpoints {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
- rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
+ rawEP.HandlePacket(r, pkt)
foundRaw = true
}
eps.mu.RUnlock()
@@ -395,7 +481,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
@@ -403,7 +489,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
// Try to find the endpoint.
eps.mu.RLock()
- ep := d.findEndpointLocked(eps, vv, id)
+ ep := d.findEndpointLocked(eps, id)
eps.mu.RUnlock()
// Fail if we didn't find one.
@@ -412,12 +498,12 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
}
// Deliver the packet.
- ep.handleControlPacket(n, id, typ, extra, vv)
+ ep.handleControlPacket(n, id, typ, extra, pkt)
return true
}
-func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic {
+func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic {
var matchedEPs []*endpointsByNic
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
@@ -445,14 +531,44 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv bu
if ep, ok := eps.endpoints[nid]; ok {
matchedEPs = append(matchedEPs, ep)
}
-
return matchedEPs
}
+// findTransportEndpoint finds a single endpoint that most closely matches the provided id.
+func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ return nil
+ }
+ // Try to find the endpoint.
+ eps.mu.RLock()
+ epsByNic := d.findEndpointLocked(eps, id)
+ // Fail if we didn't find one.
+ if epsByNic == nil {
+ eps.mu.RUnlock()
+ return nil
+ }
+
+ epsByNic.mu.RLock()
+ eps.mu.RUnlock()
+
+ mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ if !ok {
+ if mpep, ok = epsByNic.endpoints[0]; !ok {
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return nil
+ }
+ }
+
+ ep := selectEndpoint(id, mpep, epsByNic.seed)
+ epsByNic.mu.RUnlock()
+ return ep
+}
+
// findEndpointLocked returns the endpoint that most closely matches the given
// id.
-func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
- if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 {
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic {
+ if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 {
return matchedEPs[0]
}
return nil
@@ -465,7 +581,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
- return nil
+ return tcpip.ErrNotSupported
}
eps.mu.Lock()
@@ -496,3 +612,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
func isMulticastOrBroadcast(addr tcpip.Address) bool {
return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
}
+
+func isUnicast(addr tcpip.Address) bool {
+ return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
+}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 210233dc0..3b28b06d0 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -79,17 +79,17 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string)
linkEPs := make(map[string]*channel.Endpoint)
for i, linkEpName := range linkEpNames {
channelEP := channel.New(256, mtu, "")
- nicid := tcpip.NICID(i + 1)
- if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil {
+ nicID := tcpip.NICID(i + 1)
+ if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
linkEPs[linkEpName] = channelEP
- if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil {
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
t.Fatalf("AddAddress IPv4 failed: %v", err)
}
- if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
t.Fatalf("AddAddress IPv6 failed: %v", err)
}
}
@@ -156,7 +156,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+ c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
}
func TestTransportDemuxerRegister(t *testing.T) {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 86c62be25..748ce4ea5 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -43,6 +43,7 @@ type fakeTransportEndpoint struct {
proto *fakeTransportProtocol
peerAddr tcpip.Address
route stack.Route
+ uniqueID uint64
// acceptQueue is non-nil iff bound.
acceptQueue []fakeTransportEndpoint
@@ -56,8 +57,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
return nil
}
-func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
- return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto}
+func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
+ return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
func (f *fakeTransportEndpoint) Close() {
@@ -82,7 +83,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
if err != nil {
return 0, nil, err
}
- if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}); err != nil {
+ if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buffer.View(v).ToVectorisedView(),
+ }); err != nil {
return 0, nil, err
}
@@ -144,6 +148,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return nil
}
+func (f *fakeTransportEndpoint) UniqueID() uint64 {
+ return f.uniqueID
+}
+
func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
return nil
}
@@ -192,7 +200,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {
@@ -209,7 +217,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
}
}
-func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) {
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) {
// Increment the number of received control packets.
f.proto.controlCount++
}
@@ -218,15 +226,15 @@ func (f *fakeTransportEndpoint) State() uint32 {
return 0
}
-func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {
-}
+func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) {
return iptables.IPTables{}, nil
}
-func (f *fakeTransportEndpoint) Resume(*stack.Stack) {
-}
+func (f *fakeTransportEndpoint) Resume(*stack.Stack) {}
+
+func (f *fakeTransportEndpoint) Wait() {}
type fakeTransportGoodOption bool
@@ -251,7 +259,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
}
func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(stack, f, netProto), nil
+ return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
}
func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
@@ -266,7 +274,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
return true
}
@@ -337,7 +345,9 @@ func TestTransportReceive(t *testing.T) {
// Make sure packet with wrong protocol is not delivered.
buf[0] = 1
buf[2] = 0
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
@@ -346,7 +356,9 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 3
buf[2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
@@ -355,7 +367,9 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 2
buf[2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeTrans.packetCount != 1 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1)
}
@@ -408,7 +422,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 0
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = 0
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeTrans.controlCount != 0 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
}
@@ -417,7 +433,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 3
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeTrans.controlCount != 0 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
}
@@ -426,7 +444,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 2
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeTrans.controlCount != 1 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1)
}
@@ -579,7 +599,9 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- ep2.Inject(fakeNetNumber, req.ToVectorisedView())
+ ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: req.ToVectorisedView(),
+ })
aep, _, err := ep.Accept()
if err != nil || aep == nil {
@@ -598,10 +620,10 @@ func TestTransportForwarding(t *testing.T) {
t.Fatal("Response packet not forwarded")
}
- if dst := p.Header[0]; dst != 3 {
+ if dst := p.Pkt.Header.View()[0]; dst != 3 {
t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
}
- if src := p.Header[1]; src != 1 {
+ if src := p.Pkt.Header.View()[1]; src != 1 {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 678a94616..f62fd729f 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -231,6 +231,13 @@ func (s *Subnet) Broadcast() Address {
return Address(addr)
}
+// Equal returns true if s equals o.
+//
+// Needed to use cmp.Equal on Subnet as its fields are unexported.
+func (s Subnet) Equal(o Subnet) bool {
+ return s == o
+}
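go-cmp refuses to compare structs with unexported fields unless the type
supplies an Equal method, so this unblocks cmp.Equal on route tables and
similar values. A sketch (assumes the go-cmp import already used by the
tests above):

    // With Equal defined, cmp can compare values that contain Subnet.
    a, _ := tcpip.NewSubnet("\x00", "\x00")
    b, _ := tcpip.NewSubnet("\x00", "\x00")
    fmt.Println(cmp.Equal(a, b)) // true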
+
// NICID is a number that uniquely identifies a NIC.
type NICID int32
@@ -255,7 +262,7 @@ type FullAddress struct {
// This may not be used by all endpoint types.
NIC NICID
- // Addr is the network address.
+ // Addr is the network or link layer address.
Addr Address
// Port is the transport port.
@@ -301,7 +308,7 @@ type ControlMessages struct {
// HasTimestamp indicates whether Timestamp is valid/set.
HasTimestamp bool
- // Timestamp is the time (in ns) that the last packed used to create
+ // Timestamp is the time (in ns) that the last packet used to create
// the read data was received.
Timestamp int64
@@ -310,6 +317,18 @@ type ControlMessages struct {
// Inq is the number of bytes ready to be received.
Inq int32
+
+ // HasTOS indicates whether TOS is valid/set.
+ HasTOS bool
+
+ // TOS is the IPv4 type of service of the associated packet.
+ TOS int8
+
+ // HasTClass indicates whether TClass is valid/set.
+ HasTClass bool
+
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass int32
}
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
@@ -489,6 +508,11 @@ const (
// number of unread bytes in the output buffer should be returned.
SendQueueSizeOption
+ // DelayOption is used by SetSockOpt/GetSockOpt to specify if data
+ // should be sent out immediately by the transport protocol. For TCP,
+ // it determines if the Nagle algorithm is on or off.
+ DelayOption
+
// TODO(b/137664753): convert all int socket options to be handled via
// GetSockOptInt.
)
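Folding DelayOption into this block means it is now intended to go through
the int accessors; a hedged sketch (assumes ep is a TCP tcpip.Endpoint and
that SetSockOptInt accepts this option at this revision):

    // 0 sends data out immediately; a nonzero value permits delaying
    // (i.e. enables the Nagle algorithm for TCP).
    if err := ep.SetSockOptInt(tcpip.DelayOption, 1); err != nil {
        log.Printf("SetSockOptInt(DelayOption): %v", err)
    }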
@@ -501,11 +525,6 @@ type ErrorOption struct{}
// socket is to be restricted to sending and receiving IPv6 packets only.
type V6OnlyOption int
-// DelayOption is used by SetSockOpt/GetSockOpt to specify if data should be
-// sent out immediately by the transport protocol. For TCP, it determines if the
-// Nagle algorithm is on or off.
-type DelayOption int
-
// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be
// held until segments are full by the TCP transport protocol.
type CorkOption int
@@ -557,6 +576,11 @@ type KeepaliveIntervalOption time.Duration
// closed.
type KeepaliveCountOption int
+// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a
+// user-specified timeout for a given TCP connection. See RFC 5482 for
+// details.
+type TCPUserTimeoutOption time.Duration
+
// CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get
// the current congestion control algorithm.
type CongestionControlOption string
@@ -579,6 +603,16 @@ type MaxSegOption int
// A zero value indicates the default.
type TTLOption uint8
+// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
+// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state
+// before being marked closed.
+type TCPLingerTimeoutOption time.Duration
+
+// TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
+// maximum duration for which a socket lingers in the TIME_WAIT state
+// before being marked closed.
+type TCPTimeWaitTimeoutOption time.Duration
+
// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
// TTL value for multicast messages. The default is 1.
type MulticastTTLOption uint8
@@ -673,6 +707,11 @@ func (s *StatCounter) Increment() {
s.IncrementBy(1)
}
+// Decrement subtracts one from the counter.
+func (s *StatCounter) Decrement() {
+ s.IncrementBy(^uint64(0))
+}
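Decrement leans on unsigned wraparound: uint64 arithmetic is modulo 2^64 and
^uint64(0) == 2^64-1, which is congruent to -1, so adding it subtracts one.
A worked check:

    var v uint64 = 5
    v += ^uint64(0) // 5 + (2^64 - 1) mod 2^64 == 4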
+
// Value returns the current value of the counter.
func (s *StatCounter) Value() uint64 {
return atomic.LoadUint64(&s.count)
@@ -881,6 +920,23 @@ type TCPStats struct {
// successfully via Listen.
PassiveConnectionOpenings *StatCounter
+ // CurrentEstablished is the number of TCP connections for which the
+ // current state is either ESTABLISHED or CLOSE-WAIT.
+ CurrentEstablished *StatCounter
+
+ // EstablishedResets is the number of times TCP connections have made
+ // a direct transition to the CLOSED state from either the
+ // ESTABLISHED state or the CLOSE-WAIT state.
+ EstablishedResets *StatCounter
+
+ // EstablishedClosed is the number of times established TCP connections
+ // made a transition to CLOSED state.
+ EstablishedClosed *StatCounter
+
+ // EstablishedTimedout is the number of times an established connection
+	// was reset because of a keepalive timeout.
+ EstablishedTimedout *StatCounter
+
// ListenOverflowSynDrop is the number of times the listen queue overflowed
// and a SYN was dropped.
ListenOverflowSynDrop *StatCounter
@@ -1308,8 +1364,8 @@ var (
// GetDanglingEndpoints returns all dangling endpoints.
func GetDanglingEndpoints() []Endpoint {
- es := make([]Endpoint, 0, len(danglingEndpoints))
danglingEndpointsMu.Lock()
+ es := make([]Endpoint, 0, len(danglingEndpoints))
for e := range danglingEndpoints {
es = append(es, e)
}
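The reorder above matters because len(danglingEndpoints) was previously read before danglingEndpointsMu was taken, racing with concurrent map mutation. A minimal standalone sketch of the corrected pattern (illustrative names):

package main

import (
	"fmt"
	"sync"
)

var (
	mu sync.Mutex
	m  = map[string]struct{}{"a": {}, "b": {}}
)

// snapshot sizes the slice only after acquiring mu, so the len(m) read is
// covered by the same lock that guards writers.
func snapshot() []string {
	mu.Lock()
	defer mu.Unlock()
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	return keys
}

func main() { fmt.Println(len(snapshot())) }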
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
index a52262e87..48764b978 100644
--- a/pkg/tcpip/time_unsafe.go
+++ b/pkg/tcpip/time_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.9
-// +build !go1.14
+// +build !go1.15
// Check go:linkname function signatures when updating Go version.
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 9254c3dea..d8c5b5058 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -38,11 +38,3 @@ go_library(
"//pkg/waiter",
],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "icmp_packet_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 3187b336b..9c40931b5 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -31,9 +31,6 @@ type icmpPacket struct {
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
- // views is used as buffer for data when its length is large
- // enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
}
type endpointState int
@@ -58,6 +55,7 @@ type endpoint struct {
// immutable.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
+ uniqueID uint64
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -90,9 +88,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
state: stateInitial,
+ uniqueID: s.UniqueID(),
}, nil
}
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
@@ -274,13 +278,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
} else {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
- nicid := to.NIC
+ nicID := to.NIC
if e.BindNICID != 0 {
- if nicid != 0 && nicid != e.BindNICID {
+ if nicID != 0 && nicID != e.BindNICID {
return 0, nil, tcpip.ErrNoRoute
}
- nicid = e.BindNICID
+ nicID = e.BindNICID
}
toCopy := *to
@@ -291,7 +295,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
// Find the endpoint.
- r, err := e.stack.FindRoute(nicid, e.BindAddr, to.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */)
if err != nil {
return 0, nil, err
}
@@ -425,7 +429,11 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS})
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: data.ToVectorisedView(),
+ TransportHeader: buffer.View(icmpv4),
+ })
}
func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
@@ -445,13 +453,17 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
return tcpip.ErrInvalidEndpointState
}
- icmpv6.SetChecksum(0)
- icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0)))
+ dataVV := data.ToVectorisedView()
+ icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV))
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS})
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: dataVV,
+ TransportHeader: buffer.View(icmpv6),
+ })
}
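The checksum replacement is a correctness fix: ICMPv6, unlike ICMPv4, folds an IPv6 pseudo-header (source and destination addresses, payload length, next header) into its checksum, which the old two-step ^header.Checksum form omitted. A standalone sketch of the underlying one's-complement sum (illustrative only, not the package's implementation; for ICMPv6 the seed would already cover the pseudo-header):

package main

import "fmt"

// checksum computes the Internet one's-complement checksum over b,
// seeded with initial.
func checksum(b []byte, initial uint32) uint16 {
	sum := initial
	for i := 0; i+1 < len(b); i += 2 {
		sum += uint32(b[i])<<8 | uint32(b[i+1])
	}
	if len(b)%2 == 1 {
		sum += uint32(b[len(b)-1]) << 8 // Pad the odd trailing byte.
	}
	for sum>>16 != 0 {
		sum = (sum >> 16) + (sum & 0xffff) // Fold carries back in.
	}
	return ^uint16(sum)
}

func main() {
	fmt.Printf("%#04x\n", checksum([]byte{0x45, 0x00, 0x00, 0x1c}, 0))
}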
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
@@ -479,7 +491,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicid := addr.NIC
+ nicID := addr.NIC
localPort := uint16(0)
switch e.state {
case stateBound, stateConnected:
@@ -488,11 +500,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
break
}
- if nicid != 0 && nicid != e.BindNICID {
+ if nicID != 0 && nicID != e.BindNICID {
return tcpip.ErrInvalidEndpointState
}
- nicid = e.BindNICID
+ nicID = e.BindNICID
default:
return tcpip.ErrInvalidEndpointState
}
@@ -503,7 +515,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
@@ -520,14 +532,14 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// v6only is set to false and this is an ipv6 endpoint.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- id, err = e.registerWithStack(nicid, netProtos, id)
+ id, err = e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
}
e.ID = id
e.route = r.Clone()
- e.RegisterNICID = nicid
+ e.RegisterNICID = nicID
e.state = stateConnected
@@ -578,18 +590,18 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
-func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */)
+			err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */)
switch err {
case nil:
return true, nil
@@ -714,18 +726,18 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
- h := header.ICMPv4(vv.First())
+ h := header.ICMPv4(pkt.Data.First())
if h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
- h := header.ICMPv6(vv.First())
+ h := header.ICMPv6(pkt.Data.First())
if h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -753,19 +765,19 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
- pkt := &icmpPacket{
+ packet := &icmpPacket{
senderAddress: tcpip.FullAddress{
NIC: r.NICID(),
Addr: id.RemoteAddress,
},
}
- pkt.data = vv.Clone(pkt.views[:])
+ packet.data = pkt.Data
- e.rcvList.PushBack(pkt)
- e.rcvBufSize += pkt.data.Size()
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += packet.data.Size()
- pkt.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.NowNanoseconds()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
@@ -776,7 +788,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
}
// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
@@ -798,3 +810,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo {
func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index bfb16f7c3..9ce500e80 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
return true
}
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
new file mode 100644
index 000000000..44b58ff6b
--- /dev/null
+++ b/pkg/tcpip/transport/packet/BUILD
@@ -0,0 +1,38 @@
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "packet_list",
+ out = "packet_list.go",
+ package = "packet",
+ prefix = "packet",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*packet",
+ "Linker": "*packet",
+ },
+)
+
+go_library(
+ name = "packet",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ "packet_list.go",
+ ],
+ importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet",
+ imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/sleep",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
new file mode 100644
index 000000000..0010b5e5f
--- /dev/null
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -0,0 +1,362 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package packet provides the implementation of packet sockets (see
+// packet(7)). Packet sockets allow applications to:
+//
+// * manually write and inspect link, network, and transport headers
+// * receive all traffic of a given network protocol, or all protocols
+//
+// Packet sockets are similar to raw sockets, but provide even more power to
+// users, letting them effectively talk directly to the network device.
+//
+// Packet sockets skip the input and output iptables chains.
+package packet
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type packet struct {
+ packetEntry
+ // data holds the actual packet data, including any headers and
+ // payload.
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ // timestampNS is the unix time at which the packet was received.
+ timestampNS int64
+ // senderAddr is the network address of the sender.
+ senderAddr tcpip.FullAddress
+}
+
+// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
+// to have goroutines make concurrent calls into the endpoint.
+//
+// Lock order:
+// endpoint.mu
+// endpoint.rcvMu
+//
+// +stateify savable
+type endpoint struct {
+ stack.TransportEndpointInfo
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ waiterQueue *waiter.Queue
+ cooked bool
+
+ // The following fields are used to manage the receive queue and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList packetList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by mu.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ closed bool
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+}
+
+// NewEndpoint returns a new packet endpoint.
+func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ ep := &endpoint{
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ },
+ cooked: cooked,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+
+ if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
+ return nil, err
+ }
+ return ep, nil
+}
+
+// Close implements tcpip.Endpoint.Close.
+func (ep *endpoint) Close() {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.closed {
+ return
+ }
+
+ ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+
+ // Clear the receive list.
+ ep.rcvClosed = true
+ ep.rcvBufSize = 0
+ for !ep.rcvList.Empty() {
+ ep.rcvList.Remove(ep.rcvList.Front())
+ }
+
+ ep.closed = true
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (ep *endpoint) ModerateRecvBuf(copied int) {}
+
+// IPTables implements tcpip.Endpoint.IPTables.
+func (ep *endpoint) IPTables() (iptables.IPTables, error) {
+ return ep.stack.IPTables(), nil
+}
+
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ ep.rcvMu.Lock()
+
+ // If there's no data to read, return that read would block or that the
+ // endpoint is closed.
+ if ep.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if ep.rcvClosed {
+ ep.stats.ReadErrors.ReadClosed.Increment()
+ err = tcpip.ErrClosedForReceive
+ }
+ ep.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ packet := ep.rcvList.Front()
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+
+ ep.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = packet.senderAddr
+ }
+
+ return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
+}
+
+func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // TODO(b/129292371): Implement.
+ return 0, nil, tcpip.ErrInvalidOptionValue
+}
+
+// Peek implements tcpip.Endpoint.Peek.
+func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
+// disconnected, and this function always returns tcpip.ErrNotSupported.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
+// connected, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
+// with Shutdown, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
+// Listen, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with
+// Accept, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+// Bind implements tcpip.Endpoint.Bind.
+func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ // TODO(gvisor.dev/issue/173): Add Bind support.
+
+ // "By default, all packets of the specified protocol type are passed
+ // to a packet socket. To get packets only from a specific interface
+ // use bind(2) specifying an address in a struct sockaddr_ll to bind
+ // the packet socket to an interface. Fields used for binding are
+ // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex."
+ // - packet(7).
+
+ return tcpip.ErrNotSupported
+}
+
+// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
+func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ // Even a connected socket doesn't return a remote address.
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Readiness implements tcpip.Endpoint.Readiness.
+func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine whether the endpoint is readable.
+ if (mask & waiter.EventIn) != 0 {
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() || ep.rcvClosed {
+ result |= waiter.EventIn
+ }
+ ep.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be
+// used with SetSockOpt, and this function always returns
+// tcpip.ErrNotSupported.
+func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// HandlePacket implements stack.PacketEndpoint.HandlePacket.
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+ ep.rcvMu.Lock()
+
+	// Drop the packet if the endpoint is closed.
+ if ep.rcvClosed {
+ ep.rcvMu.Unlock()
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return
+ }
+
+ if ep.rcvBufSize >= ep.rcvBufSizeMax {
+ ep.rcvMu.Unlock()
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return
+ }
+
+ wasEmpty := ep.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ var packet packet
+ // TODO(b/129292371): Return network protocol.
+ if len(pkt.LinkHeader) > 0 {
+ // Get info directly from the ethernet header.
+ hdr := header.Ethernet(pkt.LinkHeader)
+ packet.senderAddr = tcpip.FullAddress{
+ NIC: nicID,
+ Addr: tcpip.Address(hdr.SourceAddress()),
+ }
+ } else {
+ // Guess the would-be ethernet header.
+ packet.senderAddr = tcpip.FullAddress{
+ NIC: nicID,
+ Addr: tcpip.Address(localAddr),
+ }
+ }
+
+ if ep.cooked {
+ // Cooked packets can simply be queued.
+ packet.data = pkt.Data
+ } else {
+ // Raw packets need their ethernet headers prepended before
+ // queueing.
+ var linkHeader buffer.View
+ if len(pkt.LinkHeader) == 0 {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
+ }
+ combinedVV := linkHeader.ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ }
+ packet.timestampNS = ep.stack.NowNanoseconds()
+
+ ep.rcvList.PushBack(&packet)
+ ep.rcvBufSize += packet.data.Size()
+
+ ep.rcvMu.Unlock()
+ ep.stats.PacketsReceived.Increment()
+ // Notify waiters that there's data to be read.
+ if wasEmpty {
+ ep.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// State implements socket.Socket.State.
+func (ep *endpoint) State() uint32 {
+ return 0
+}
+
+// Info returns a copy of the endpoint info.
+func (ep *endpoint) Info() tcpip.EndpointInfo {
+ ep.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := ep.TransportEndpointInfo
+ ep.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (ep *endpoint) Stats() tcpip.EndpointStats {
+ return &ep.stats
+}
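To make the new endpoint's read-side contract concrete, a hedged usage fragment (assumptions: s is an already-configured *stack.Stack, the packet/header/tcpip/waiter packages are imported, and error paths are abbreviated; this is not code from the change):

wq := new(waiter.Queue)
ep, err := packet.NewEndpoint(s, true /* cooked */, header.IPv4ProtocolNumber, wq)
if err != nil {
	panic(err)
}
defer ep.Close()

var from tcpip.FullAddress
if _, _, rerr := ep.Read(&from); rerr == tcpip.ErrWouldBlock {
	// Register a waiter.Entry on wq for waiter.EventIn and retry once
	// the notification fires.
}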
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
new file mode 100644
index 000000000..9b88f17e4
--- /dev/null
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -0,0 +1,72 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package packet
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves packet.data field.
+func (p *packet) saveData() buffer.VectorisedView {
+	// We cannot save p.data directly as its backing views may alias other
+	// in-struct memory, which is not allowed by the state framework
+	// (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads packet.data field.
+func (p *packet) loadData(data buffer.VectorisedView) {
+	// NOTE: We cannot clone data into a preallocated views array here
+	// because data.views is not guaranteed to be loaded by now. Plus,
+	// data.views will be allocated anyway, so there is little point in
+	// such an optimization.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (ep *endpoint) beforeSave() {
+	// Stop incoming packets from being handled (and from mutating endpoint state).
+ // The lock will be released after saveRcvBufSizeMax(), which would have
+ // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ ep.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) saveRcvBufSizeMax() int {
+ max := ep.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ ep.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ ep.rcvMu.Unlock()
+ return max
+}
+
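The beforeSave/saveRcvBufSizeMax pair above implements a lock-across-hooks handoff: lock in the pre-save hook, capture the guarded field, publish a 0 sentinel so late packets keep being dropped, then unlock. A compact standalone sketch of the same shape (illustrative names, not the stateify framework itself):

package main

import (
	"fmt"
	"sync"
)

type ep struct {
	mu  sync.Mutex
	max int // 0 acts as a "reject all incoming work" sentinel.
}

func (e *ep) beforeSave() { e.mu.Lock() }

func (e *ep) saveMax() int {
	v := e.max
	e.max = 0 // Keep dropping work even after mu is released below.
	e.mu.Unlock()
	return v
}

func main() {
	e := &ep{max: 32 * 1024}
	e.beforeSave()
	fmt.Println(e.saveMax(), e.max) // 32768 0
}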
+// loadRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) loadRcvBufSizeMax(max int) {
+ ep.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (ep *endpoint) afterLoad() {
+ // StackFromEnv is a stack used specifically for save/restore.
+ ep.stack = stack.StackFromEnv
+
+	// TODO(gvisor.dev/issue/173): Once bind is supported, choose the right NIC.
+ if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index fba598d51..00991ac8e 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -4,14 +4,14 @@ load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
go_template_instance(
- name = "packet_list",
- out = "packet_list.go",
+ name = "raw_packet_list",
+ out = "raw_packet_list.go",
package = "raw",
- prefix = "packet",
+ prefix = "rawPacket",
template = "//pkg/ilist:generic_list",
types = {
- "Element": "*packet",
- "Linker": "*packet",
+ "Element": "*rawPacket",
+ "Linker": "*rawPacket",
},
)
@@ -20,8 +20,8 @@ go_library(
srcs = [
"endpoint.go",
"endpoint_state.go",
- "packet_list.go",
"protocol.go",
+ "raw_packet_list.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
@@ -34,14 +34,7 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/packet",
"//pkg/waiter",
],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "packet_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b4c660859..5aafe2615 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -17,8 +17,7 @@
//
// * manually write and inspect transport layer headers and payloads
// * receive all traffic of a given transport protocol (e.g. ICMP or UDP)
-// * optionally write and inspect network layer and link layer headers for
-// packets
+// * optionally write and inspect network layer headers of packets
//
// Raw sockets don't have any notion of ports, and incoming packets are
// demultiplexed solely by protocol number. Thus, a raw UDP endpoint will
@@ -38,15 +37,11 @@ import (
)
// +stateify savable
-type packet struct {
- packetEntry
+type rawPacket struct {
+ rawPacketEntry
// data holds the actual packet data, including any headers and
// payload.
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- // views is pre-allocated space to back data. As long as the packet is
- // made up of fewer than 8 buffer.Views, no extra allocation is
- // necessary to store packet data.
- views [8]buffer.View `state:"nosave"`
// timestampNS is the unix time at which the packet was received.
timestampNS int64
// senderAddr is the network address of the sender.
@@ -72,7 +67,7 @@ type endpoint struct {
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
- rcvList packetList
+ rcvList rawPacketList
rcvBufSizeMax int `state:".(int)"`
rcvBufSize int
rcvClosed bool
@@ -90,7 +85,6 @@ type endpoint struct {
}
// NewEndpoint returns a raw endpoint for the given protocols.
-// TODO(b/129292371): IP_HDRINCL and AF_PACKET.
func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */)
}
@@ -187,17 +181,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
return buffer.View{}, tcpip.ControlMessages{}, err
}
- packet := e.rcvList.Front()
- e.rcvList.Remove(packet)
- e.rcvBufSize -= packet.data.Size()
+ pkt := e.rcvList.Front()
+ e.rcvList.Remove(pkt)
+ e.rcvBufSize -= pkt.data.Size()
e.rcvMu.Unlock()
if addr != nil {
- *addr = packet.senderAddr
+ *addr = pkt.senderAddr
}
- return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
+ return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil
}
// Write implements tcpip.Endpoint.Write.
@@ -344,13 +338,18 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
switch e.NetProto {
case header.IPv4ProtocolNumber:
if !e.associated {
- if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil {
+ if err := route.WriteHeaderIncludedPacket(tcpip.PacketBuffer{
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ }); err != nil {
return 0, nil, err
}
break
}
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
- if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
+ if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ }); err != nil {
return 0, nil, err
}
@@ -561,7 +560,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) {
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -602,16 +601,17 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu
wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
- packet := &packet{
+ packet := &rawPacket{
senderAddr: tcpip.FullAddress{
NIC: route.NICID(),
Addr: route.RemoteAddress,
},
}
- combinedVV := netHeader.ToVectorisedView()
- combinedVV.Append(vv)
- packet.data = combinedVV.Clone(packet.views[:])
+ networkHeader := append(buffer.View(nil), pkt.NetworkHeader...)
+ combinedVV := networkHeader.ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
packet.timestampNS = e.stack.NowNanoseconds()
e.rcvList.PushBack(packet)
@@ -643,3 +643,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo {
func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index a6c7cc43a..33bfb56cd 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -20,15 +20,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// saveData saves packet.data field.
-func (p *packet) saveData() buffer.VectorisedView {
+// saveData saves rawPacket.data field.
+func (p *rawPacket) saveData() buffer.VectorisedView {
// We cannot save p.data directly as p.data.views may alias to p.views,
// which is not allowed by state framework (in-struct pointer).
return p.data.Clone(nil)
}
-// loadData loads packet.data field.
-func (p *packet) loadData(data buffer.VectorisedView) {
+// loadData loads rawPacket.data field.
+func (p *rawPacket) loadData(data buffer.VectorisedView) {
// NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
// here because data.views is not guaranteed to be loaded by now. Plus,
// data.views will be allocated anyway so there really is little point
@@ -86,7 +86,9 @@ func (ep *endpoint) Resume(s *stack.Stack) {
}
}
- if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
- panic(err)
+ if ep.associated {
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
+ panic(err)
+ }
}
}
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index a2512d666..f30aa2a4a 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -17,13 +17,19 @@ package raw
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/packet"
"gvisor.dev/gvisor/pkg/waiter"
)
-// EndpointFactory implements stack.UnassociatedEndpointFactory.
+// EndpointFactory implements stack.RawFactory.
type EndpointFactory struct{}
-// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
-func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint.
+func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
}
+
+// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint.
+func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return packet.NewEndpoint(stack, cooked, netProto, waiterQueue)
+}
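A hedged wiring sketch for the renamed factory (the stack.Options shape and its RawFactory field are assumptions consistent with the interface rename above, not shown in this change):

// Hedged sketch: install the factory when building the stack so raw and
// packet endpoints can be created on demand.
s := stack.New(stack.Options{
	NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
	TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
	RawFactory:         raw.EndpointFactory{},
})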
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index aed70e06f..3b353d56c 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -28,6 +28,7 @@ go_library(
"forwarder.go",
"protocol.go",
"rcv.go",
+ "rcv_state.go",
"reno.go",
"sack.go",
"sack_scoreboard.go",
@@ -44,6 +45,7 @@ go_library(
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
"//pkg/tcpip",
@@ -51,6 +53,7 @@ go_library(
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
@@ -60,17 +63,9 @@ go_library(
],
)
-filegroup(
- name = "autogen",
- srcs = [
- "tcp_segment_list.go",
- ],
- visibility = ["//:sandbox"],
-)
-
go_test(
name = "tcp_test",
- size = "small",
+ size = "medium",
srcs = [
"dual_stack_test.go",
"sack_scoreboard_test.go",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 844959fa0..5422ae80c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -241,8 +242,15 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
+ // Now inherit any socket options that should be inherited from the
+ // listening endpoint.
+	// listenEP is nil when the endpoint was created via a Forwarder,
+	// hence the nil check.
+ if l.listenEP != nil {
+ l.listenEP.propagateInheritableOptions(n)
+ }
+
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil {
n.Close()
return nil, err
}
@@ -268,8 +276,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
- cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS))
- ep, err := l.createConnectingEndpoint(s, cookie, irs, opts)
+ isn := generateSecureISN(s.id, l.stack.Seed())
+ ep, err := l.createConnectingEndpoint(s, isn, irs, opts)
if err != nil {
return nil, err
}
@@ -288,7 +296,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Perform the 3-way handshake.
h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
- h.resetToSynRcvd(cookie, irs, opts)
+ h.resetToSynRcvd(isn, irs, opts)
if err := h.execute(); err != nil {
ep.Close()
if l.listenEP != nil {
@@ -297,7 +305,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, err
}
ep.mu.Lock()
- ep.state = StateEstablished
+ ep.isConnectNotified = true
ep.mu.Unlock()
// Update the receive window scaling. We can't do it before the
@@ -349,6 +357,14 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
}
}
+// propagateInheritableOptions propagates any options set on the listening
+// endpoint to the newly created endpoint.
+func (e *endpoint) propagateInheritableOptions(n *endpoint) {
+ e.mu.Lock()
+ n.userTimeout = e.userTimeout
+ e.mu.Unlock()
+}
+
// handleSynSegment is called in its own goroutine once the listening endpoint
// receives a SYN segment. It is responsible for completing the handshake and
// queueing the new endpoint for acceptance.
@@ -359,6 +375,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
defer decSynRcvdCount()
defer e.decSynRcvdCount()
defer s.decRef()
+
n, err := ctx.createEndpointAndPerformHandshake(s, opts)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -366,6 +383,11 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
return
}
ctx.removePendingEndpoint(n)
+ // Start the protocol goroutine.
+ wq := &waiter.Queue{}
+ n.startAcceptedLoop(wq)
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
e.deliverAccepted(n)
}
@@ -399,8 +421,19 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
- switch s.flags {
- case header.TCPFlagSyn:
+ if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) {
+ // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
+ // must be sent in response to a SYN-ACK while in the listen
+ // state to prevent completing a handshake from an old SYN.
+ e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil)
+ return
+ }
+
+ // TODO(b/143300739): Use the userMSS of the listening socket
+ // for accepted sockets.
+
+ switch {
+ case s.flags == header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
if incSynRcvdCount() {
// Only handle the syn if the following conditions hold
@@ -433,19 +466,18 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
//
// Enable Timestamp option if the original syn did have
// the timestamp option specified.
- mss := mssForRoute(&s.route)
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
TSVal: tcpTimeStamp(timeStampOffset()),
TSEcr: opts.TSVal,
- MSS: uint16(mss),
+ MSS: mssForRoute(&s.route),
}
e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
}
- case header.TCPFlagAck:
+ case (s.flags & header.TCPFlagAck) != 0:
if e.acceptQueueIsFull() {
// Silently drop the ack as the application can't accept
// the connection at this point. The ack will be
@@ -459,6 +491,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
}
if !synCookiesInUse() {
+ // When not using SYN cookies, as per RFC 793, section 3.9, page 64:
+ // Any acknowledgment is bad if it arrives on a connection still in
+ // the LISTEN state. An acceptable reset segment should be formed
+ // for any arriving ACK-bearing segment. The RST should be
+ // formatted as follows:
+ //
+ // <SEQ=SEG.ACK><CTL=RST>
+ //
// Send a reset as this is an ACK for which there is no
// half open connections and we are not using cookies
// yet.
@@ -519,7 +559,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.tsOffset = 0
// Switch state to connected.
+ // We do not use transitionToStateEstablishedLocked here as there is
+ // no handshake state available when doing a SYN cookie based accept.
+ n.stack.Stats().TCP.CurrentEstablished.Increment()
n.state = StateEstablished
+ n.isConnectNotified = true
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
@@ -530,6 +574,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// number of goroutines as we do check before
// entering here that there was at least some
// space available in the backlog.
+
+ // Start the protocol goroutine.
+ wq := &waiter.Queue{}
+ n.startAcceptedLoop(wq)
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
}
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 5ea036bea..cdd69f360 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "encoding/binary"
"sync"
"time"
@@ -22,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -78,9 +80,6 @@ type handshake struct {
// mss is the maximum segment size received from the peer.
mss uint16
- // amss is the maximum segment size advertised by us to the peer.
- amss uint16
-
// sndWndScale is the send window scale, as defined in RFC 1323. A
// negative value means no scaling is supported by the peer.
sndWndScale int
@@ -142,7 +141,32 @@ func (h *handshake) resetState() {
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24)
+ h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed())
+}
+
+// generateSecureISN generates a secure Initial Sequence number based on the
+// recommendation here https://tools.ietf.org/html/rfc6528#page-3.
+func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value {
+ isnHasher := jenkins.Sum32(seed)
+ isnHasher.Write([]byte(id.LocalAddress))
+ isnHasher.Write([]byte(id.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, id.LocalPort)
+ isnHasher.Write(portBuf)
+ binary.LittleEndian.PutUint16(portBuf, id.RemotePort)
+ isnHasher.Write(portBuf)
+	// The time period here is 64ns. This is similar to what Linux uses to
+	// generate a sequence number that wraps less than once per MSL
+	// (2 minutes).
+	//
+	// A 64ns clock ticks 10^9/64 = 15625000 times in a second.
+	// To wrap the whole 32-bit space would require
+	// 2^32/15625000 ~ 275 seconds.
+	//
+	// Which roughly guarantees that we won't reuse the ISN for a new
+	// connection on the same tuple for at least ~275s.
+ isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6)
+ return seqnum.Value(isn)
}
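A standalone check of the wrap-time arithmetic in the comment (the >>6 in the line above is the divide-by-64 that makes each tick 64ns):

package main

import "fmt"

func main() {
	ticksPerSecond := int64(1e9) / 64 // 15625000 ticks of 64ns per second.
	wrapSeconds := (int64(1) << 32) / ticksPerSecond
	fmt.Println(ticksPerSecond, wrapSeconds) // 15625000 274 (~274.9s).
}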
// effectiveRcvWndScale returns the effective receive window scale to be used.
@@ -194,6 +218,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// acceptable if the ack field acknowledges the SYN.
if s.flagIsSet(header.TCPFlagRst) {
if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 {
+ // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK
+ // was acceptable then signal the user "error: connection reset", drop
+ // the segment, enter CLOSED state, delete TCB, and return."
+ h.ep.mu.Lock()
+ h.ep.workerCleanup = true
+ h.ep.mu.Unlock()
+ // Although the RFC above calls out ECONNRESET, Linux actually returns
+ // ECONNREFUSED here so we do as well.
return tcpip.ErrConnectionRefused
}
return nil
@@ -228,6 +260,11 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// and the handshake is completed.
if s.flagIsSet(header.TCPFlagAck) {
h.state = handshakeCompleted
+
+ h.ep.mu.Lock()
+ h.ep.transitionToStateEstablishedLocked(h)
+ h.ep.mu.Unlock()
+
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
return nil
}
@@ -275,6 +312,15 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
return nil
}
+ // RFC 793, Section 3.9, page 69, states that in the SYN-RCVD state, a
+ // sequence number outside of the window causes an ACK with the proper seq
+ // number and "After sending the acknowledgment, drop the unacceptable
+ // segment and return."
+ if !s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) {
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd)
+ return nil
+ }
+
if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 {
// We received two SYN segments with different sequence
// numbers, so we reset this and restart the whole
@@ -319,6 +365,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
h.state = handshakeCompleted
+ h.ep.mu.Lock()
+ h.ep.transitionToStateEstablishedLocked(h)
+ h.ep.mu.Unlock()
+
return nil
}
@@ -445,7 +495,7 @@ func (h *handshake) execute() *tcpip.Error {
// Send the initial SYN segment and loop until the handshake is
// completed.
- h.ep.amss = mssForRoute(&h.ep.route)
+ h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
@@ -607,19 +657,14 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu
return nil
}
-// sendTCP sends a TCP segment with the provided options via the provided
-// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *tcpip.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) {
optLen := len(opts)
- // Allocate a buffer for the TCP header.
- hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
-
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
- }
-
+ hdr := &pkt.Header
+ packetSize := pkt.DataSize
+ off := pkt.DataOffset
// Initialize the header.
tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
+ pkt.TransportHeader = buffer.View(tcp)
tcp.Encode(&header.TCPFields{
SrcPort: id.LocalPort,
DstPort: id.RemotePort,
@@ -631,7 +676,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
})
copy(tcp[header.TCPMinimumSize:], opts)
- length := uint16(hdr.UsedLength() + data.Size())
+ length := uint16(hdr.UsedLength() + packetSize)
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
// Only calculate the checksum if offloading isn't supported.
if gso != nil && gso.NeedsCsum {
@@ -641,14 +686,79 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
// header and data and get the right sum of the TCP packet.
tcp.SetChecksum(xsum)
} else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
- xsum = header.ChecksumVV(data, xsum)
+ xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
+}
+
+func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ optLen := len(opts)
+ if rcvWnd > 0xffff {
+ rcvWnd = 0xffff
+ }
+
+ mss := int(gso.MSS)
+ n := (data.Size() + mss - 1) / mss
+
+ // Allocate one big slice for all the headers.
+ hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen
+ buf := make([]byte, n*hdrSize)
+ pkts := make([]tcpip.PacketBuffer, n)
+ for i := range pkts {
+ pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize])
+ }
+
+ size := data.Size()
+ off := 0
+ for i := 0; i < n; i++ {
+ packetSize := mss
+ if packetSize > size {
+ packetSize = size
+ }
+ size -= packetSize
+ pkts[i].DataOffset = off
+ pkts[i].DataSize = packetSize
+ pkts[i].Data = data
+ buildTCPHdr(r, id, &pkts[i], flags, seq, ack, rcvWnd, opts, gso)
+ off += packetSize
+ seq = seq.Add(seqnum.Size(packetSize))
+ }
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos})
+ if err != nil {
+ r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent))
+ }
+ r.Stats().TCP.SegmentsSent.IncrementBy(uint64(sent))
+ return err
+}
+
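The batching path's segmentation is plain ceil-division over gso.MSS, with the final segment taking the remainder. A standalone sketch with concrete numbers mirrors the loop above:

package main

import "fmt"

func main() {
	size, mss := 4000, 1460
	n := (size + mss - 1) / mss // ceil(4000/1460) = 3 segments.
	for i := 0; i < n; i++ {
		packetSize := mss
		if packetSize > size {
			packetSize = size
		}
		size -= packetSize
		fmt.Println("segment", i, "bytes", packetSize) // 1460, 1460, 1080.
	}
}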
+// sendTCP sends a TCP segment with the provided options via the provided
+// network endpoint and under the provided identity.
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ optLen := len(opts)
+ if rcvWnd > 0xffff {
+ rcvWnd = 0xffff
+ }
+
+ if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
+ return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso)
+ }
+
+ pkt := tcpip.PacketBuffer{
+ Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
+ DataOffset: 0,
+ DataSize: data.Size(),
+ Data: data,
+ }
+ buildTCPHdr(r, id, &pkt, flags, seq, ack, rcvWnd, opts, gso)
+
if ttl == 0 {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(gso, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
+ if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, pkt); err != nil {
r.Stats().TCP.SegmentSendErrors.Increment()
return err
}
@@ -754,10 +864,26 @@ func (e *endpoint) handleClose() *tcpip.Error {
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
+ if e.state == StateEstablished || e.state == StateCloseWait {
+ e.stack.Stats().TCP.EstablishedResets.Increment()
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ }
e.state = StateError
e.HardError = err
- if err != tcpip.ErrConnectionReset {
- e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
+ // The exact sequence number to be used for the RST is the same as the
+ // one used by Linux. We need to handle the case of window being shrunk
+ // which can cause sndNxt to be outside the acceptable window on the
+ // receiver.
+ //
+ // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more
+ // information.
+ sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd)
+ resetSeqNum := sndWndEnd
+ if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) {
+ resetSeqNum = e.snd.sndNxt
+ }
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0)
}
}
@@ -771,6 +897,99 @@ func (e *endpoint) completeWorkerLocked() {
}
}
+// transitionToStateEstablishedLocked transitions a given endpoint
+// to an established state using the handshake parameters provided.
+// It also initializes sender/receiver if required.
+func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
+ if e.snd == nil {
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+ }
+ if e.rcv == nil {
+ rcvBufSize := seqnum.Size(e.receiveBufferSize())
+ e.rcvListMu.Lock()
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ e.rcvAutoParams.prevCopied = int(h.rcvWnd)
+ e.rcvListMu.Unlock()
+ }
+ h.ep.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.state = StateEstablished
+}
+
+// transitionToStateCloseLocked ensures that the endpoint is
+// cleaned up from the transport demuxer, "before" moving to
+// StateClose. This will ensure that no packet will be
+// delivered to this endpoint from the demuxer when the endpoint
+// is transitioned to StateClose.
+func (e *endpoint) transitionToStateCloseLocked() {
+ if e.state == StateClose {
+ return
+ }
+ e.cleanupLocked()
+ e.state = StateClose
+ e.stack.Stats().TCP.EstablishedClosed.Increment()
+}
+
+// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed
+// segment to an endpoint other than the current one. This is called
+// only when the endpoint is in StateClose and we want to deliver the segment
+// to any other listening endpoint. We reply with RST if we cannot find one.
+func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
+ ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route)
+ if ep == nil {
+ replyWithReset(s)
+ s.decRef()
+ return
+ }
+ ep.(*endpoint).enqueueSegment(s)
+}
+
+func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
+ if e.rcv.acceptable(s.sequenceNumber, 0) {
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+ s.decRef()
+ e.mu.Lock()
+ switch e.state {
+		// In case of a RST in CLOSE-WAIT, Linux moves
+		// the socket to the closed state with an error
+		// set to indicate EPIPE.
+ //
+ // Technically this seems to be at odds w/ RFC.
+ // As per https://tools.ietf.org/html/rfc793#section-2.7
+ // page 69 the behavior for a segment arriving
+ // w/ RST bit set in CLOSE-WAIT is inlined below.
+ //
+ // ESTABLISHED
+ // FIN-WAIT-1
+ // FIN-WAIT-2
+ // CLOSE-WAIT
+
+ // If the RST bit is set then, any outstanding RECEIVEs and
+ // SEND should receive "reset" responses. All segment queues
+ // should be flushed. Users should also receive an unsolicited
+ // general "connection reset" signal. Enter the CLOSED state,
+ // delete the TCB, and return.
+ case StateCloseWait:
+ e.transitionToStateCloseLocked()
+ e.HardError = tcpip.ErrAborted
+ e.mu.Unlock()
+ return false, nil
+ default:
+ e.mu.Unlock()
+ return false, tcpip.ErrConnectionReset
+ }
+ }
+ return true, nil
+}
+
// handleSegments pulls segments from the queue and processes them. It returns
// no error if the protocol loop should continue, an error otherwise.
func (e *endpoint) handleSegments() *tcpip.Error {
@@ -788,14 +1007,34 @@ func (e *endpoint) handleSegments() *tcpip.Error {
}
if s.flagIsSet(header.TCPFlagRst) {
- if e.rcv.acceptable(s.sequenceNumber, 0) {
- // RFC 793, page 37 states that "in all states
- // except SYN-SENT, all reset (RST) segments are
- // validated by checking their SEQ-fields." So
- // we only process it if it's acceptable.
- s.decRef()
- return tcpip.ErrConnectionReset
+ if ok, err := e.handleReset(s); !ok {
+ return err
}
+ } else if s.flagIsSet(header.TCPFlagSyn) {
+ // See: https://tools.ietf.org/html/rfc5961#section-4.1
+ // 1) If the SYN bit is set, irrespective of the sequence number, TCP
+ // MUST send an ACK (also referred to as challenge ACK) to the remote
+ // peer:
+ //
+ // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
+ //
+ // After sending the acknowledgment, TCP MUST drop the unacceptable
+ // segment and stop processing further.
+ //
+ // By sending an ACK, the remote peer is challenged to confirm the loss
+ // of the previous connection and the request to start a new connection.
+ // A legitimate peer, after restart, would not have a TCB in the
+ // synchronized state. Thus, when the ACK arrives, the peer should send
+ // a RST segment back with the sequence number derived from the ACK
+ // field that caused the RST.
+
+ // This RST will confirm that the remote peer has indeed closed the
+ // previous connection. Upon receipt of a valid RST, the local TCP
+ // endpoint MUST terminate its connection. The local TCP endpoint
+ // should then rely on SYN retransmission from the remote end to
+ // re-establish the connection.
+
+ e.snd.sendAck()
} else if s.flagIsSet(header.TCPFlagAck) {
// Patch the window size in the segment according to the
// send window scale.
@@ -804,7 +1043,33 @@ func (e *endpoint) handleSegments() *tcpip.Error {
// RFC 793, page 41 states that "once in the ESTABLISHED
// state all segments must carry current acknowledgment
// information."
- e.rcv.handleRcvdSegment(s)
+ drop, err := e.rcv.handleRcvdSegment(s)
+ if err != nil {
+ s.decRef()
+ return err
+ }
+ if drop {
+ s.decRef()
+ continue
+ }
+
+ // Now check if the received segment has caused us to transition
+ // to a CLOSED state; if so, terminate processing and do not
+ // invoke the sender.
+ e.mu.RLock()
+ state := e.state
+ e.mu.RUnlock()
+ if state == StateClose {
+ // When we get into StateClose while processing from the queue,
+ // return immediately and let the protocolMainloop handle it.
+ //
+ // We can reach StateClose only while processing a previous segment
+ // or a notification from the protocolMainLoop (caller goroutine).
+ // This means that with this return, the segment dequeue below can
+ // never occur on a closed endpoint.
+ s.decRef()
+ return nil
+ }
e.snd.handleRcvdSegment(s)
}
s.decRef()
@@ -830,14 +1095,27 @@ func (e *endpoint) handleSegments() *tcpip.Error {
// keepalive packets periodically when the connection is idle. If we don't hear
// from the other side after a number of tries, we terminate the connection.
func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
+ e.mu.RLock()
+ userTimeout := e.userTimeout
+ e.mu.RUnlock()
+
e.keepalive.Lock()
if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
e.keepalive.Unlock()
return nil
}
+ // If a userTimeout is set then abort the connection if it is
+ // exceeded.
+ if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 {
+ e.keepalive.Unlock()
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
+ return tcpip.ErrTimeout
+ }
+
if e.keepalive.unacked >= e.keepalive.count {
e.keepalive.Unlock()
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
return tcpip.ErrTimeout
}
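The user-timeout check above mirrors Linux's TCP_USER_TIMEOUT semantics: once keepalive probes are outstanding, the connection is aborted when no ACK has arrived within the user-specified window. A minimal sketch of the predicate, using hypothetical standalone names in place of the endpoint fields:

    // Sketch only: abort when a user timeout is set, the most recent
    // inbound ACK is older than that timeout, and at least one
    // keepalive probe is still unacknowledged.
    func shouldAbortOnKeepalive(userTimeout time.Duration, lastRcvdAckTime time.Time, unackedProbes uint32) bool {
    	return userTimeout != 0 &&
    		time.Since(lastRcvdAckTime) >= userTimeout &&
    		unackedProbes > 0
    }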
@@ -854,7 +1132,6 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
// whether it is enabled for this endpoint.
func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
e.keepalive.Lock()
- defer e.keepalive.Unlock()
if receivedData {
e.keepalive.unacked = 0
}
@@ -862,6 +1139,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
// data to send.
if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
e.keepalive.timer.disable()
+ e.keepalive.Unlock()
return
}
if e.keepalive.unacked > 0 {
@@ -869,6 +1147,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
} else {
e.keepalive.timer.enable(e.keepalive.idle)
}
+ e.keepalive.Unlock()
}
// disableKeepaliveTimer stops the keepalive timer.
@@ -903,7 +1182,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
e.mu.Unlock()
-
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -932,21 +1210,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
return err
}
-
- // Transfer handshake state to TCP connection. We disable
- // receive window scaling if the peer doesn't support it
- // (indicated by a negative send window scale).
- e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
-
- rcvBufSize := seqnum.Size(e.receiveBufferSize())
- e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
- // boot strap the auto tuning algorithm. Starting at zero will
- // result in a large step function on the first proper causing
- // the window to just go to a really large value after the first
- // RTT itself.
- e.rcvAutoParams.prevCopied = initialRcvWnd
- e.rcvListMu.Unlock()
}
e.keepalive.timer.init(&e.keepalive.waker)
@@ -954,7 +1217,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
- e.state = StateEstablished
drained := e.drainDone != nil
e.mu.Unlock()
if drained {
@@ -985,13 +1247,20 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
{
w: &closeWaker,
f: func() *tcpip.Error {
- return tcpip.ErrConnectionAborted
+ // This means the socket is being closed because
+ // the TCP_FIN_WAIT2 timeout was hit. Just
+ // mark the socket as closed.
+ e.mu.Lock()
+ e.transitionToStateCloseLocked()
+ e.mu.Unlock()
+ return nil
},
},
{
w: &e.snd.resendWaker,
f: func() *tcpip.Error {
if !e.snd.retransmitTimerExpired() {
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
return tcpip.ErrTimeout
}
return nil
@@ -1028,17 +1297,18 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
e.mu.Unlock()
}
+
if n&notifyClose != 0 && closeTimer == nil {
- // Reset the connection 3 seconds after
- // the endpoint has been closed.
- //
- // The timer could fire in background
- // when the endpoint is drained. That's
- // OK as the loop here will not honor
- // the firing until the undrain arrives.
- closeTimer = time.AfterFunc(3*time.Second, func() {
- closeWaker.Assert()
- })
+ e.mu.Lock()
+ if e.state == StateFinWait2 && e.closed {
+ // The socket has been closed and we are in FIN_WAIT2
+ // so start the FIN_WAIT2 timer.
+ closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() {
+ closeWaker.Assert()
+ })
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ }
+ e.mu.Unlock()
}
if n&notifyKeepaliveChanged != 0 {
@@ -1054,12 +1324,20 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
return err
}
}
- if e.state != StateError {
+ if e.state != StateClose && e.state != StateError {
+ // Only block the worker if the endpoint
+ // is not in closed state or error state.
close(e.drainDone)
<-e.undrain
}
}
+ if n&notifyTickleWorker != 0 {
+ // Just a tickle notification. No need to do
+ // anything.
+ return nil
+ }
+
return nil
},
},
@@ -1086,15 +1364,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
e.rcvListMu.Unlock()
- e.mu.RLock()
+ e.mu.Lock()
if e.workerCleanup {
e.notifyProtocolGoroutine(notifyClose)
}
- e.mu.RUnlock()
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
- for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList {
+
+ for e.state != StateTimeWait && e.state != StateClose && e.state != StateError {
+ e.mu.Unlock()
e.workMu.Unlock()
v, _ := s.Fetch(true)
e.workMu.Lock()
@@ -1110,15 +1389,168 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
return nil
}
+ e.mu.Lock()
+ }
+
+ state := e.state
+ e.mu.Unlock()
+ var reuseTW func()
+ if state == StateTimeWait {
+ // Disable the close timer as we are now entering the real TIME_WAIT.
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
+ // Mark the current sleeper done so as to free all associated
+ // wakers.
+ s.Done()
+ // Wake up any waiters before we enter TIME_WAIT.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ reuseTW = e.doTimeWait()
}
// Mark endpoint as closed.
e.mu.Lock()
if e.state != StateError {
- e.state = StateClose
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ e.transitionToStateCloseLocked()
}
+
// Lock released below.
epilogue()
+ // epilogue removes the endpoint from the transport-demuxer and
+ // unlocks e.mu. Now that no new segments can get enqueued to this
+ // endpoint, try to re-match the segment to a different endpoint
+ // as the current endpoint is closed.
+ for {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ break
+ }
+
+ e.tryDeliverSegmentFromClosedEndpoint(s)
+ }
+
+ // A new SYN was received during TIME_WAIT and we need to abort
+ // TIME_WAIT early and redirect the segment to the listener queue.
+ if reuseTW != nil {
+ reuseTW()
+ }
+
return nil
}
+
+// handleTimeWaitSegments processes segments received during TIME_WAIT
+// state.
+func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) {
+ checkRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ checkRequeue = false
+ break
+ }
+ extTW, newSyn := e.rcv.handleTimeWaitSegment(s)
+ if newSyn {
+ info := e.EndpointInfo.TransportEndpointInfo
+ newID := info.ID
+ newID.RemoteAddress = ""
+ newID.RemotePort = 0
+ netProtos := []tcpip.NetworkProtocolNumber{info.NetProto}
+ // If the local address is an IPv4 address then also
+ // look for IPv6 dual stack endpoints that might be
+ // listening on the local address.
+ if newID.LocalAddress.To4() != "" {
+ netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber}
+ }
+ for _, netProto := range netProtos {
+ if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil {
+ tcpEP := listenEP.(*endpoint)
+ if EndpointState(tcpEP.State()) == StateListen {
+ reuseTW = func() {
+ tcpEP.enqueueSegment(s)
+ }
+ // We explicitly do not decRef
+ // the segment as it's still
+ // valid and being reflected to
+ // a listening endpoint.
+ return false, reuseTW
+ }
+ }
+ }
+ }
+ if extTW {
+ extendTimeWait = true
+ }
+ s.decRef()
+ }
+ if checkRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+ return extendTimeWait, nil
+}
+
+// doTimeWait is responsible for handling the TCP behaviour once a socket
+// enters the TIME_WAIT state. Optionally it can return a closure that
+// should be executed after releasing the endpoint registrations. This is
+// done in cases where a new SYN is received during TIME_WAIT that carries
+// a sequence number larger than any seen on the connection.
+func (e *endpoint) doTimeWait() (twReuse func()) {
+ // Trigger a 2 * MSL time wait state. During this period
+ // we will drop all incoming segments.
+ // NOTE: On Linux this is not configurable and is fixed at 60 seconds.
+ timeWaitDuration := DefaultTCPTimeWaitTimeout
+
+ // Get the stack wide configuration.
+ var tcpTW tcpip.TCPTimeWaitTimeoutOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil {
+ timeWaitDuration = time.Duration(tcpTW)
+ }
+
+ const newSegment = 1
+ const notification = 2
+ const timeWaitDone = 3
+
+ s := sleep.Sleeper{}
+ s.AddWaker(&e.newSegmentWaker, newSegment)
+ s.AddWaker(&e.notificationWaker, notification)
+
+ var timeWaitWaker sleep.Waker
+ s.AddWaker(&timeWaitWaker, timeWaitDone)
+ timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert)
+ defer timeWaitTimer.Stop()
+
+ for {
+ e.workMu.Unlock()
+ v, _ := s.Fetch(true)
+ e.workMu.Lock()
+ switch v {
+ case newSegment:
+ extendTimeWait, reuseTW := e.handleTimeWaitSegments()
+ if reuseTW != nil {
+ return reuseTW
+ }
+ if extendTimeWait {
+ timeWaitTimer.Reset(timeWaitDuration)
+ }
+ case notification:
+ n := e.fetchNotifications()
+ if n&notifyClose != 0 {
+ return nil
+ }
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ // Ignore extending TIME_WAIT during a
+ // save. For sockets in TIME_WAIT we just
+ // terminate the TIME_WAIT early.
+ e.handleTimeWaitSegments()
+ }
+ close(e.drainDone)
+ <-e.undrain
+ return nil
+ }
+ case timeWaitDone:
+ return nil
+ }
+ }
+}
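As a usage note, the TIME_WAIT duration consumed by doTimeWait above is a stack-wide knob; a hedged sketch of tuning it, where stk is an assumed *stack.Stack and the 30-second value is illustrative:

    opt := tcpip.TCPTimeWaitTimeoutOption(30 * time.Second)
    if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
    	// Handle the error; the protocol's SetOption (extended later
    	// in this change) accepts this option type.
    }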
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a1b784b49..fe629aa40 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -30,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tmutex"
@@ -121,6 +122,11 @@ const (
notifyReset
notifyKeepaliveChanged
notifyMSSChanged
+ // notifyTickleWorker is used to tickle the protocol main loop during a
+ // restore after we update the endpoint state to the correct one. This
+ // ensures the loop terminates if the final state of the endpoint is,
+ // say, TIME_WAIT.
+ notifyTickleWorker
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -287,6 +293,7 @@ type endpoint struct {
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue `state:"wait"`
+ uniqueID uint64
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
@@ -319,6 +326,11 @@ type endpoint struct {
state EndpointState `state:".(EndpointState)"`
+ // origEndpointState is only used during a restore phase to save the
+ // endpoint state at restore time as the socket is moved to its correct
+ // state.
+ origEndpointState EndpointState `state:"nosave"`
+
isPortReserved bool `state:"manual"`
isRegistered bool
boundNICID tcpip.NICID `state:"manual"`
@@ -330,6 +342,11 @@ type endpoint struct {
// disabling SO_BROADCAST, albeit as a NOOP.
broadcast bool
+ // Values used to reserve a port or register a transport endpoint
+ // (whichever happens first).
+ boundBindToDevice tcpip.NICID
+ boundPortFlags ports.Flags
+
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
// endpoints with v6only set to false, this could include multiple
@@ -411,7 +428,7 @@ type endpoint struct {
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
- userMSS int
+ userMSS uint16
// The following fields are used to manage the send buffer. When
// segments are ready to be sent, they are added to sndQueue and the
@@ -458,6 +475,12 @@ type endpoint struct {
// without hearing a response, the connection is closed.
keepalive keepalive
+ // userTimeout, if non-zero, specifies a user-specified timeout for
+ // a connection with pending data to send. A connection that has
+ // pending unacked data will be forcibly aborted if the timeout is
+ // reached without any data being acked.
+ userTimeout time.Duration
+
// pendingAccepted is a synchronization primitive used to track number
// of connections that are queued up to be delivered to the accepted
// channel. We use this to ensure that all goroutines blocked on writing
@@ -502,6 +525,36 @@ type endpoint struct {
// TODO(b/142022063): Add ability to save and restore per endpoint stats.
stats Stats `state:"nosave"`
+
+ // tcpLingerTimeout is the maximum amount of time a socket
+ // stays in FIN_WAIT2 state before being marked
+ // closed.
+ tcpLingerTimeout time.Duration
+
+ // closed indicates that the user has called Close on the
+ // endpoint and at this point the endpoint is only around
+ // to complete the TCP shutdown.
+ closed bool
+}
+
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
+// calculateAdvertisedMSS calculates the MSS to advertise.
+//
+// If userMSS is non-zero and is not greater than the maximum possible MSS for
+// r, it will be used; otherwise, the maximum possible MSS will be used.
+func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
+ // The maximum possible MSS is dependent on the route.
+ maxMSS := mssForRoute(&r)
+
+ if userMSS != 0 && userMSS < maxMSS {
+ return userMSS
+ }
+
+ return maxMSS
}
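A quick worked example of calculateAdvertisedMSS, assuming mssForRoute returns the route MTU minus the 20-byte minimum TCP header (as defined at the bottom of this file):

    // route MTU 1500 -> maxMSS = 1500 - 20 = 1480
    // userMSS 1400   -> advertised MSS 1400 (user value below the cap)
    // userMSS 9000   -> advertised MSS 1480 (capped to the route)
    // userMSS 0      -> advertised MSS 1480 (unset; use the maximum)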
// StopWork halts packet processing. Only to be used in tests.
@@ -550,6 +603,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
interval: 75 * time.Second,
count: 9,
},
+ uniqueID: s.UniqueID(),
}
var ss SendBufferSizeOption
@@ -572,6 +626,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.rcvAutoParams.disabled = !bool(mrb)
}
+ var de DelayEnabled
+ if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
+ e.SetSockOptInt(tcpip.DelayOption, 1)
+ }
+
+ var tcpLT tcpip.TCPLingerTimeoutOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil {
+ e.tcpLingerTimeout = time.Duration(tcpLT)
+ }
+
if p := s.GetTCPProbe(); p != nil {
e.probe = p
}
@@ -659,6 +723,13 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) {
// with it. It must be called only once and with no other concurrent calls to
// the endpoint.
func (e *endpoint) Close() {
+ e.mu.Lock()
+ closed := e.closed
+ e.mu.Unlock()
+ if closed {
+ return
+ }
+
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
@@ -671,14 +742,18 @@ func (e *endpoint) Close() {
// in Listen() when trying to register.
if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
e.isPortReserved = false
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
}
+ // Mark endpoint as closed.
+ e.closed = true
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
tcpip.AddDanglingEndpoint(e)
@@ -704,9 +779,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
go func() {
defer close(done)
for n := range e.acceptedChan {
- n.mu.Lock()
- n.resetConnectionLocked(tcpip.ErrConnectionAborted)
- n.mu.Unlock()
+ n.notifyProtocolGoroutine(notifyReset)
n.Close()
}
}()
@@ -732,16 +805,19 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
e.isPortReserved = false
}
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
e.route.Release()
+ e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
@@ -752,7 +828,9 @@ func (e *endpoint) initialReceiveWindow() int {
if rcvWnd > math.MaxUint16 {
rcvWnd = math.MaxUint16
}
- routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2
+
+ // Use the user supplied MSS, if available.
+ routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2
if rcvWnd > routeWnd {
rcvWnd = routeWnd
}
@@ -1133,16 +1211,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
e.sndBufMu.Unlock()
return nil
- default:
- return nil
- }
-}
-
-// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- // Lower 2 bits represents ECN bits. RFC 3168, section 23.1
- const inetECNMask = 3
- switch v := opt.(type) {
case tcpip.DelayOption:
if v == 0 {
atomic.StoreUint32(&e.delay, 0)
@@ -1154,6 +1222,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ default:
+ return nil
+ }
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ // Lower 2 bits represent ECN bits. RFC 3168, section 23.1
+ const inetECNMask = 3
+ switch v := opt.(type) {
case tcpip.CorkOption:
if v == 0 {
atomic.StoreUint32(&e.cork, 0)
@@ -1184,9 +1262,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.bindToDevice = 0
return nil
}
- for nicid, nic := range e.stack.NICInfo() {
+ for nicID, nic := range e.stack.NICInfo() {
if nic.Name == string(v) {
- e.bindToDevice = nicid
+ e.bindToDevice = nicID
return nil
}
}
@@ -1206,7 +1284,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidOptionValue
}
e.mu.Lock()
- e.userMSS = int(userMSS)
+ e.userMSS = uint16(userMSS)
e.mu.Unlock()
e.notifyProtocolGoroutine(notifyMSSChanged)
return nil
@@ -1262,6 +1340,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
return nil
+ case tcpip.TCPUserTimeoutOption:
+ e.mu.Lock()
+ e.userTimeout = time.Duration(v)
+ e.mu.Unlock()
+ return nil
+
case tcpip.BroadcastOption:
e.mu.Lock()
e.broadcast = v != 0
@@ -1319,6 +1403,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.TCPLingerTimeoutOption:
+ e.mu.Lock()
+ if v < 0 {
+ // A negative value effectively disables the TCP linger timeout.
+ v = 0
+ }
+ var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption
+ if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil {
+ // We were unable to retrieve a stack config, just use
+ // the DefaultTCPLingerTimeout.
+ if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) {
+ stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout)
+ }
+ }
+ // Cap it to the stack wide TCPLinger timeout.
+ if v > stkTCPLingerTimeout {
+ v = stkTCPLingerTimeout
+ }
+ e.tcpLingerTimeout = time.Duration(v)
+ e.mu.Unlock()
+ return nil
+
default:
return nil
}
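The two new per-endpoint options above flow through the generic SetSockOpt path; a brief usage sketch, where ep is an assumed tcpip.Endpoint and the durations are illustrative:

    // Abort the connection if pending data stays unacked for 30s.
    if err := ep.SetSockOpt(tcpip.TCPUserTimeoutOption(30 * time.Second)); err != nil {
    	// handle error
    }
    // Linger at most 10s in FIN_WAIT2 after Close; values above the
    // stack-wide TCPLingerTimeoutOption are capped, as implemented above.
    if err := ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(10 * time.Second)); err != nil {
    	// handle error
    }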
@@ -1345,6 +1451,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
+
case tcpip.SendBufferSizeOption:
e.sndBufMu.Lock()
v := e.sndBufSize
@@ -1357,8 +1464,16 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
e.rcvListMu.Unlock()
return v, nil
+ case tcpip.DelayOption:
+ var o int
+ if v := atomic.LoadUint32(&e.delay); v != 0 {
+ o = 1
+ }
+ return o, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
}
- return -1, tcpip.ErrUnknownProtocolOption
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -1379,13 +1494,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = header.TCPDefaultMSS
return nil
- case *tcpip.DelayOption:
- *o = 0
- if v := atomic.LoadUint32(&e.delay); v != 0 {
- *o = 1
- }
- return nil
-
case *tcpip.CorkOption:
*o = 0
if v := atomic.LoadUint32(&e.cork); v != 0 {
@@ -1496,6 +1604,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.keepalive.Unlock()
return nil
+ case *tcpip.TCPUserTimeoutOption:
+ e.mu.Lock()
+ *o = tcpip.TCPUserTimeoutOption(e.userTimeout)
+ e.mu.Unlock()
+ return nil
+
case *tcpip.OutOfBandInlineOption:
// We don't currently support disabling this option.
*o = 1
@@ -1530,6 +1644,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.RUnlock()
return nil
+ case *tcpip.TCPLingerTimeoutOption:
+ e.mu.Lock()
+ *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout)
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -1602,7 +1722,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return tcpip.ErrAlreadyConnected
}
- nicid := addr.NIC
+ nicID := addr.NIC
switch e.state {
case StateBound:
// If we're already bound to a NIC but the caller is requesting
@@ -1611,11 +1731,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
break
}
- if nicid != 0 && nicid != e.boundNICID {
+ if nicID != 0 && nicID != e.boundNICID {
return tcpip.ErrNoRoute
}
- nicid = e.boundNICID
+ nicID = e.boundNICID
case StateInitial:
// Nothing to do. We'll eventually fill-in the gaps in the ID (if any)
@@ -1634,7 +1754,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
@@ -1649,7 +1769,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice)
if err != nil {
return err
}
@@ -1664,7 +1784,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// src IP to ensure that for a given tuple (srcIP, destIP,
// destPort) the offset used as a starting point is the same to
// ensure that we can cycle through the port space effectively.
- h := jenkins.Sum32(e.stack.PortSeed())
+ h := jenkins.Sum32(e.stack.Seed())
h.Write([]byte(e.ID.LocalAddress))
h.Write([]byte(e.ID.RemoteAddress))
portBuf := make([]byte, 2)
@@ -1678,15 +1798,18 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
// reusePort is false below because connect cannot reuse a port even if
// reusePort was set.
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) {
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) {
return false, nil
}
id := e.ID
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
+ switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
case nil:
+ // Port picking successful. Save the details of
+ // the selected port.
e.ID = id
+ e.boundBindToDevice = e.bindToDevice
return true, nil
case tcpip.ErrPortInUse:
return false, nil
@@ -1702,14 +1825,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// before Connect: in such a case we don't want to hold on to
// reservations anymore.
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
e.isPortReserved = false
}
e.isRegistered = true
e.state = StateConnecting
e.route = r.Clone()
- e.boundNICID = nicid
+ e.boundNICID = nicID
e.effectiveNetProtos = netProtos
e.connectingAddress = connectingAddr
@@ -1729,6 +1852,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
e.state = StateEstablished
+ e.stack.Stats().TCP.CurrentEstablished.Increment()
}
if run {
@@ -1749,9 +1873,8 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
// peer.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.mu.Lock()
- defer e.mu.Unlock()
e.shutdownFlags |= flags
-
+ finQueued := false
switch {
case e.state.connected():
// Close for read.
@@ -1766,6 +1889,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// the connection with a RST.
if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
e.notifyProtocolGoroutine(notifyReset)
+ e.mu.Unlock()
return nil
}
}
@@ -1784,14 +1908,11 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
s := newSegmentFromView(&e.route, e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
-
+ finQueued = true
// Mark endpoint as closed.
e.sndClosed = true
e.sndBufMu.Unlock()
-
- // Tell protocol goroutine to close.
- e.sndCloseWaker.Assert()
}
case e.state == StateListen:
@@ -1799,11 +1920,20 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
if flags&tcpip.ShutdownRead != 0 {
e.notifyProtocolGoroutine(notifyClose)
}
-
default:
+ e.mu.Unlock()
return tcpip.ErrNotConnected
}
-
+ e.mu.Unlock()
+ if finQueued {
+ if e.workMu.TryLock() {
+ e.handleClose()
+ e.workMu.Unlock()
+ } else {
+ // Tell protocol goroutine to close.
+ e.sndCloseWaker.Assert()
+ }
+ }
return nil
}
@@ -1844,6 +1974,15 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
return nil
}
+ if e.state == StateInitial {
+ // If listen is called on an unbound socket, the socket is
+ // automatically bound to a random free port with the local
+ // address set to INADDR_ANY.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return err
+ }
+ }
+
// Endpoint must be bound before it can transition to listen mode.
if e.state != StateBound {
e.stats.ReadErrors.InvalidEndpointState.Increment()
@@ -1851,7 +1990,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil {
return err
}
@@ -1895,12 +2034,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrWouldBlock
}
- // Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
-
- return n, wq, nil
+ return n, n.waiterQueue, nil
}
// Bind binds the endpoint to a specific local port and optionally address.
@@ -1908,6 +2042,10 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
+ return e.bindLocked(addr)
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// Don't allow binding once endpoint is not in the initial state
// anymore. This is because once the endpoint goes into a connected or
// listen state, it is already bound.
@@ -1932,26 +2070,33 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice)
+ flags := ports.Flags{
+ LoadBalanced: e.reusePort,
+ }
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice)
if err != nil {
return err
}
+ e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = flags
e.isPortReserved = true
e.effectiveNetProtos = netProtos
e.ID.LocalPort = port
// Any failures beyond this point must remove the port registration.
- defer func(bindToDevice tcpip.NICID) {
+ defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice)
e.isPortReserved = false
e.effectiveNetProtos = nil
e.ID.LocalPort = 0
e.ID.LocalAddress = ""
e.boundNICID = 0
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
}
- }(e.bindToDevice)
+ }(e.boundPortFlags, e.boundBindToDevice)
// If an address is specified, we must ensure that it's one of our
// local addresses.
@@ -2001,8 +2146,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
- s := newSegment(r, id, vv)
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+ s := newSegment(r, id, pkt)
if !s.parse() {
e.stack.Stats().MalformedRcvdPackets.Increment()
e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
@@ -2025,6 +2170,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
e.stack.Stats().TCP.ResetsReceived.Increment()
}
+ e.enqueueSegment(s)
+}
+
+func (e *endpoint) enqueueSegment(s *segment) {
// Send packet to worker goroutine.
if e.segmentQueue.enqueue(s) {
e.newSegmentWaker.Assert()
@@ -2037,7 +2186,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
switch typ {
case stack.ControlPacketTooBig:
e.sndBufMu.Lock()
@@ -2327,11 +2476,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
return s
}
-func (e *endpoint) initGSO() {
- if e.route.Capabilities()&stack.CapabilityGSO == 0 {
- return
- }
-
+func (e *endpoint) initHardwareGSO() {
gso := &stack.GSO{}
switch e.route.NetProto {
case header.IPv4ProtocolNumber:
@@ -2349,6 +2494,18 @@ func (e *endpoint) initGSO() {
e.gso = gso
}
+func (e *endpoint) initGSO() {
+ if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ e.initHardwareGSO()
+ } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 {
+ e.gso = &stack.GSO{
+ MaxSize: e.route.GSOMaxSize(),
+ Type: stack.GSOSW,
+ NeedsCsum: false,
+ }
+ }
+}
+
// State implements tcpip.Endpoint.State. It exports the endpoint's protocol
// state for diagnostics.
func (e *endpoint) State() uint32 {
@@ -2371,6 +2528,23 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+// Wait implements stack.TransportEndpoint.Wait.
+func (e *endpoint) Wait() {
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp)
+ defer e.waiterQueue.EventUnregister(&waitEntry)
+ for {
+ e.mu.Lock()
+ running := e.workerRunning
+ e.mu.Unlock()
+ if !running {
+ break
+ }
+ <-notifyCh
+ }
+}
+
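Wait complements Close during teardown; a minimal sketch of the intended pairing, assuming the caller holds no endpoint locks:

    ep.Close() // Kick off the TCP shutdown sequence.
    ep.Wait()  // Block until workerRunning drops to false.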
func mssForRoute(r *stack.Route) uint16 {
+ // TODO(b/143359391): Respect TCP Min and Max size.
return uint16(r.MTU() - header.TCPMinimumSize)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index eae17237e..7aa4c3f0e 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -78,7 +78,7 @@ func (e *endpoint) beforeSave() {
}
fallthrough
case StateError, StateClose:
- for e.state == StateError && e.workerRunning {
+ for (e.state == StateError || e.state == StateClose) && e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
e.mu.Lock()
@@ -165,6 +165,12 @@ func (e *endpoint) loadState(state EndpointState) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
+ // Freeze segment queue before registering to prevent any segments
+ // from being delivered while it is being restored.
+ e.origEndpointState = e.state
+ // Restore the endpoint to StateInitial as it will be moved to
+ // its origEndpointState during Resume.
+ e.state = StateInitial
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
@@ -173,8 +179,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.workMu.Init()
+ state := e.origEndpointState
- state := e.state
switch state {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
@@ -189,12 +195,13 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
bind := func() {
- e.state = StateInitial
if len(e.BindAddr) == 0 {
e.BindAddr = e.ID.LocalAddress
}
- if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil {
- panic("endpoint binding failed: " + err.String())
+ addr := e.BindAddr
+ port := e.ID.LocalPort
+ if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil {
+ panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err))
}
}
@@ -217,6 +224,16 @@ func (e *endpoint) Resume(s *stack.Stack) {
if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
panic("endpoint connecting failed: " + err.String())
}
+ e.mu.Lock()
+ e.state = e.origEndpointState
+ closed := e.closed
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ if state == StateFinWait2 && closed {
+ // If the endpoint has been closed then make sure we notify so
+ // that the FIN_WAIT2 timer is started after a restore.
+ e.notifyProtocolGoroutine(notifyClose)
+ }
connectedLoading.Done()
case StateListen:
tcpip.AsyncLoading.Add(1)
@@ -263,8 +280,12 @@ func (e *endpoint) Resume(s *stack.Stack) {
tcpip.AsyncLoading.Done()
}()
}
- fallthrough
+ e.state = StateClose
+ e.stack.CompleteTransportEndpointCleanup(e)
+ tcpip.DeleteDanglingEndpoint(e)
case StateError:
+ e.state = StateError
+ e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 63666f0b3..4983bca81 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -18,7 +18,6 @@ import (
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -63,8 +62,8 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
- s := newSegment(r, id, vv)
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+ s := newSegment(r, id, pkt)
defer s.decRef()
// We only care about well-formed SYN packets.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index db40785d3..bc718064c 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -23,6 +23,7 @@ package tcp
import (
"strings"
"sync"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -54,12 +55,23 @@ const (
// MaxUnprocessedSegments is the maximum number of unprocessed segments
// that can be queued for a given endpoint.
MaxUnprocessedSegments = 300
+
+ // DefaultTCPLingerTimeout is the amount of time that sockets linger in
+ // FIN_WAIT_2 state before being marked closed.
+ DefaultTCPLingerTimeout = 60 * time.Second
+
+ // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger
+ // in TIME_WAIT state before being marked closed.
+ DefaultTCPTimeWaitTimeout = 60 * time.Second
)
// SACKEnabled option can be used to enable SACK support in the TCP
// protocol. See: https://tools.ietf.org/html/rfc2018.
type SACKEnabled bool
+// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol.
+type DelayEnabled bool
+
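DelayEnabled is a stack-wide default consumed by newEndpoint (see the endpoint.go hunk above); a sketch of enabling it so new endpoints start with Nagle's algorithm on, where stk is an assumed *stack.Stack:

    if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.DelayEnabled(true)); err != nil {
    	// handle error
    }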
// SendBufferSizeOption allows the default, min and max send buffer sizes for
// TCP endpoints to be queried or configured.
type SendBufferSizeOption struct {
@@ -84,11 +96,14 @@ const (
type protocol struct {
mu sync.Mutex
sackEnabled bool
+ delayEnabled bool
sendBufferSize SendBufferSizeOption
recvBufferSize ReceiveBufferSizeOption
congestionControl string
availableCongestionControl []string
moderateReceiveBuffer bool
+ tcpLingerTimeout time.Duration
+ tcpTimeWaitTimeout time.Duration
}
// Number returns the tcp protocol number.
@@ -97,7 +112,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new tcp endpoint.
-func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, waiterQueue), nil
}
@@ -126,8 +141,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
- s := newSegment(r, id, vv)
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
+ s := newSegment(r, id, pkt)
defer s.decRef()
if !s.parse() || !s.csumValid {
@@ -147,13 +162,26 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo
func replyWithReset(s *segment) {
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
+ ack := seqnum.Value(0)
+ flags := byte(header.TCPFlagRst)
+ // As per RFC 793 page 35 (Reset Generation)
+ // 1. If the connection does not exist (CLOSED) then a reset is sent
+ // in response to any incoming segment except another reset. In
+ // particular, SYNs addressed to a non-existent connection are rejected
+ // by this means.
+
+ // If the incoming segment has an ACK field, the reset takes its
+ // sequence number from the ACK field of the segment, otherwise the
+ // reset has sequence number zero and the ACK field is set to the sum
+ // of the sequence number and segment length of the incoming segment.
+ // The connection remains in the CLOSED state.
if s.flagIsSet(header.TCPFlagAck) {
seq = s.ackNumber
+ } else {
+ flags |= header.TCPFlagAck
+ ack = s.sequenceNumber.Add(s.logicalLen())
}
-
- ack := s.sequenceNumber.Add(s.logicalLen())
-
- sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
+ sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
}
// SetOption implements TransportProtocol.SetOption.
@@ -165,6 +193,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case DelayEnabled:
+ p.mu.Lock()
+ p.delayEnabled = bool(v)
+ p.mu.Unlock()
+ return nil
+
case SendBufferSizeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
@@ -202,6 +236,24 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case tcpip.TCPLingerTimeoutOption:
+ if v < 0 {
+ v = 0
+ }
+ p.mu.Lock()
+ p.tcpLingerTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPTimeWaitTimeoutOption:
+ if v < 0 {
+ v = 0
+ }
+ p.mu.Lock()
+ p.tcpTimeWaitTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -216,6 +268,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case *DelayEnabled:
+ p.mu.Lock()
+ *v = DelayEnabled(p.delayEnabled)
+ p.mu.Unlock()
+ return nil
+
case *SendBufferSizeOption:
p.mu.Lock()
*v = p.sendBufferSize
@@ -246,6 +304,18 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case *tcpip.TCPLingerTimeoutOption:
+ p.mu.Lock()
+ *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout)
+ p.mu.Unlock()
+ return nil
+
+ case *tcpip.TCPTimeWaitTimeoutOption:
+ p.mu.Lock()
+ *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -258,5 +328,7 @@ func NewProtocol() stack.TransportProtocol {
recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
+ tcpLingerTimeout: DefaultTCPLingerTimeout,
+ tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
}
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index e90f9a7d9..0a5534959 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -18,6 +18,7 @@ import (
"container/heap"
"time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
@@ -49,16 +50,20 @@ type receiver struct {
pendingRcvdSegments segmentHeap
pendingBufUsed seqnum.Size
pendingBufSize seqnum.Size
+
+ // Time when the last ack was received.
+ lastRcvdAckTime time.Time `state:".(unixTime)"`
}
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver {
return &receiver{
- ep: ep,
- rcvNxt: irs + 1,
- rcvAcc: irs.Add(rcvWnd + 1),
- rcvWnd: rcvWnd,
- rcvWndScale: rcvWndScale,
- pendingBufSize: pendingBufSize,
+ ep: ep,
+ rcvNxt: irs + 1,
+ rcvAcc: irs.Add(rcvWnd + 1),
+ rcvWnd: rcvWnd,
+ rcvWndScale: rcvWndScale,
+ pendingBufSize: pendingBufSize,
+ lastRcvdAckTime: time.Now(),
}
}
@@ -204,15 +209,20 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Handle ACK (not FIN-ACK, which we handled above) during one of the
// shutdown states.
- if s.flagIsSet(header.TCPFlagAck) {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
r.ep.mu.Lock()
switch r.ep.state {
case StateFinWait1:
r.ep.state = StateFinWait2
+ // Notify protocol goroutine that we have received an
+ // ACK to our FIN so that it can start the FIN_WAIT2
+ // timer to abort the connection if the other side does
+ // not close within 2MSL.
+ r.ep.notifyProtocolGoroutine(notifyClose)
case StateClosing:
r.ep.state = StateTimeWait
case StateLastAck:
- r.ep.state = StateClose
+ r.ep.transitionToStateCloseLocked()
}
r.ep.mu.Unlock()
}
@@ -253,25 +263,110 @@ func (r *receiver) updateRTT() {
r.ep.rcvListMu.Unlock()
}
-// handleRcvdSegment handles TCP segments directed at the connection managed by
-// r as they arrive. It is called by the protocol main loop.
-func (r *receiver) handleRcvdSegment(s *segment) {
+func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) {
+ r.ep.rcvListMu.Lock()
+ rcvClosed := r.ep.rcvClosed || r.closed
+ r.ep.rcvListMu.Unlock()
+
+ // If we are in one of the shutdown states then we need to do
+ // additional checks before we try and process the segment.
+ switch state {
+ case StateCloseWait, StateClosing, StateLastAck:
+ if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
+ s.decRef()
+ // Just drop the segment as we have
+ // already received a FIN and this
+ // segment is after the sequence number
+ // for the FIN.
+ return true, nil
+ }
+ fallthrough
+ case StateFinWait1:
+ fallthrough
+ case StateFinWait2:
+ // If we are closed for reads (either due to an
+ // incoming FIN or the user calling shutdown(..,
+ // SHUT_RD)), then any data past rcvNxt should
+ // trigger a RST.
+ endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+ if rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
+ s.decRef()
+ return true, tcpip.ErrConnectionAborted
+ }
+ if state == StateFinWait1 {
+ break
+ }
+
+ // If it's a retransmission of an old data segment
+ // or a pure ACK then allow it.
+ if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) ||
+ s.logicalLen() == 0 {
+ break
+ }
+
+ // In FIN-WAIT2, if the socket is fully
+ // closed (not owned by the application on
+ // our end), then the only acceptable segment
+ // is a FIN. Since a FIN can technically also
+ // carry data, we verify that the segment
+ // carrying a FIN ends at exactly e.rcvNxt+1.
+ //
+ // From RFC793 page 25.
+ //
+ // For sequence number purposes, the SYN is
+ // considered to occur before the first actual
+ // data octet of the segment in which it occurs,
+ // while the FIN is considered to occur after
+ // the last actual data octet in a segment in
+ // which it occurs.
+ if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
+ s.decRef()
+ return true, tcpip.ErrConnectionAborted
+ }
+ }
+
// We don't care about receive processing anymore if the receive side
// is closed.
- if r.closed {
- return
+ //
+ // NOTE: We still want to permit a FIN as it's possible only our
+ // end has closed and the peer is yet to send a FIN. Hence we
+ // compare only the payload.
+ segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+ if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) {
+ return true, nil
+ }
+ return false, nil
+}
+
+// handleRcvdSegment handles TCP segments directed at the connection managed by
+// r as they arrive. It is called by the protocol main loop.
+func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
+ r.ep.mu.RLock()
+ state := r.ep.state
+ closed := r.ep.closed
+ r.ep.mu.RUnlock()
+
+ if state != StateEstablished {
+ drop, err := r.handleRcvdSegmentClosing(s, state, closed)
+ if drop || err != nil {
+ return drop, err
+ }
}
segLen := seqnum.Size(s.data.Size())
segSeq := s.sequenceNumber
// If the sequence number range is outside the acceptable range, just
- // send an ACK. This is according to RFC 793, page 37.
+ // send an ACK and stop further processing of the segment.
+ // This is according to RFC 793, page 68.
if !r.acceptable(segSeq, segLen) {
r.ep.snd.sendAck()
- return
+ return true, nil
}
+ // Store the time of the last ack.
+ r.lastRcvdAckTime = time.Now()
+
// Defer segment processing if it can't be consumed now.
if !r.consumeSegment(s, segSeq, segLen) {
if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
@@ -288,7 +383,7 @@ func (r *receiver) handleRcvdSegment(s *segment) {
// have to retransmit.
r.ep.snd.sendAck()
}
- return
+ return false, nil
}
// Since we consumed a segment update the receiver's RTT estimate
@@ -315,4 +410,67 @@ func (r *receiver) handleRcvdSegment(s *segment) {
r.pendingBufUsed -= s.logicalLen()
s.decRef()
}
+ return false, nil
+}
+
+// handleTimeWaitSegment handles inbound segments received when the endpoint
+// has entered the TIME_WAIT state.
+func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) {
+ segSeq := s.sequenceNumber
+ segLen := seqnum.Size(s.data.Size())
+
+ // Just silently drop any RST packets in TIME_WAIT. We do not support
+ // TIME_WAIT assassination; as a result we conform with fix 1 as
+ // described in https://tools.ietf.org/html/rfc1337#section-3.
+ if s.flagIsSet(header.TCPFlagRst) {
+ return false, false
+ }
+
+ // If it's a SYN and the sequence number is higher than any seen before
+ // for this connection then try and redirect it to a listening endpoint
+ // if available.
+ //
+ // RFC 1122:
+ // "When a connection is [...] on TIME-WAIT state [...]
+ // [a TCP] MAY accept a new SYN from the remote TCP to
+ // reopen the connection directly, if it:
+
+ // (1) assigns its initial sequence number for the new
+ // connection to be larger than the largest sequence
+ // number it used on the previous connection incarnation,
+ // and
+
+ // (2) returns to TIME-WAIT state if the SYN turns out
+ // to be an old duplicate".
+ if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) {
+ return false, true
+ }
+
+ // Drop the segment if it does not contain an ACK.
+ if !s.flagIsSet(header.TCPFlagAck) {
+ return false, false
+ }
+
+ // Update Timestamp if required. See RFC7323, section-4.3.
+ if r.ep.sendTSOk && s.parsedOptions.TS {
+ r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq)
+ }
+
+ if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) {
+ // If it's a FIN-ACK then reset the TIME_WAIT timer and send an
+ // ACK, as it indicates our final ACK may have been lost.
+ r.ep.snd.sendAck()
+ return true, false
+ }
+
+ // If the sequence number range is outside the acceptable range or
+ // carries data then just send an ACK. This is according to RFC 793,
+ // page 37.
+ //
+ // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt.
+ if segSeq != r.rcvNxt || segLen != 0 {
+ r.ep.snd.sendAck()
+ }
+ return false, false
}
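In summary, the decision order implemented by handleTimeWaitSegment above (a sketch, not normative):

    // RST                           -> drop silently (RFC 1337, fix 1)
    // SYN with seq > rcvNxt         -> report newSyn; the caller may hand
    //                                  the segment to a listening endpoint
    // segment without ACK           -> drop
    // FIN ending exactly at rcvNxt  -> send ACK and restart the 2*MSL timer
    // any other out-of-window or
    // data-bearing segment          -> send ACK, do not extend TIME_WAIT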
diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go
new file mode 100644
index 000000000..2bf21a2e7
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv_state.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveLastRcvdAckTime is invoked by stateify.
+func (r *receiver) saveLastRcvdAckTime() unixTime {
+ return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()}
+}
+
+// loadLastRcvdAckTime is invoked by stateify.
+func (r *receiver) loadLastRcvdAckTime(unix unixTime) {
+ r.lastRcvdAckTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index ea725d513..1c10da5ca 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -18,6 +18,7 @@ import (
"sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -60,13 +61,13 @@ type segment struct {
xmitTime time.Time `state:".(unixTime)"`
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment {
+func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment {
s := &segment{
refCnt: 1,
id: id,
route: r.Clone(),
}
- s.data = vv.Clone(s.views[:])
+ s.data = pkt.Data.Clone(s.views[:])
s.rcvdTime = time.Now()
return s
}
@@ -99,8 +100,14 @@ func (s *segment) clone() *segment {
return t
}
-func (s *segment) flagIsSet(flag uint8) bool {
- return (s.flags & flag) != 0
+// flagIsSet checks if at least one flag in flags is set in s.flags.
+func (s *segment) flagIsSet(flags uint8) bool {
+ return s.flags&flags != 0
+}
+
+// flagsAreSet checks if all flags in flags are set in s.flags.
+func (s *segment) flagsAreSet(flags uint8) bool {
+ return s.flags&flags == flags
}
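To illustrate the any-of versus all-of distinction, take a FIN-ACK segment where s.flags == header.TCPFlagFin|header.TCPFlagAck:

    s.flagIsSet(header.TCPFlagFin | header.TCPFlagRst)   // true: FIN matches
    s.flagsAreSet(header.TCPFlagFin | header.TCPFlagAck) // true: both set
    s.flagsAreSet(header.TCPFlagFin | header.TCPFlagRst) // false: RST missing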
func (s *segment) decRef() {
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 8332a0179..8a947dc66 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -28,8 +28,11 @@ import (
)
const (
- // minRTO is the minimum allowed value for the retransmit timeout.
- minRTO = 200 * time.Millisecond
+ // MinRTO is the minimum allowed value for the retransmit timeout.
+ MinRTO = 200 * time.Millisecond
+
+ // MaxRTO is the maximum allowed value for the retransmit timeout.
+ MaxRTO = 120 * time.Second
// InitialCwnd is the initial congestion window.
InitialCwnd = 10
@@ -134,6 +137,10 @@ type sender struct {
// rttMeasureTime is the time when the rttMeasureSeqNum was sent.
rttMeasureTime time.Time `state:".(unixTime)"`
+ // firstRetransmittedSegXmitTime is the original transmit time of
+ // the first segment that was retransmitted due to RTO expiration.
+ firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"`
+
closed bool
writeNext *segment
writeList segmentList
@@ -392,8 +399,8 @@ func (s *sender) updateRTO(rtt time.Duration) {
s.rto = s.rtt.srtt + 4*s.rtt.rttvar
s.rtt.Unlock()
- if s.rto < minRTO {
- s.rto = minRTO
+ if s.rto < MinRTO {
+ s.rto = MinRTO
}
}
@@ -438,8 +445,30 @@ func (s *sender) retransmitTimerExpired() bool {
s.ep.stack.Stats().TCP.Timeouts.Increment()
s.ep.stats.SendErrors.Timeouts.Increment()
- // Give up if we've waited more than a minute since the last resend.
- if s.rto >= 60*time.Second {
+ // Give up if we've waited more than a minute since the last resend or
+ // if a user timeout is set and we have exceeded the user-specified
+ // timeout since the first retransmission.
+ s.ep.mu.RLock()
+ uto := s.ep.userTimeout
+ s.ep.mu.RUnlock()
+
+ if s.firstRetransmittedSegXmitTime.IsZero() {
+ // Record the original xmitTime of the segment we are about to
+ // retransmit as the first retransmission time. This is needed
+ // because by the time the retransmit timer fires, the segment
+ // has already been sent and has gone unacknowledged for one
+ // full RTO.
+ s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime
+ }
+
+ elapsed := time.Since(s.firstRetransmittedSegXmitTime)
+ remaining := MaxRTO
+ if uto != 0 {
+ // Cap to the user specified timeout if one is specified.
+ remaining = uto - elapsed
+ }
+
+ if remaining <= 0 || s.rto >= MaxRTO {
return false
}
@@ -447,6 +476,11 @@ func (s *sender) retransmitTimerExpired() bool {
// below.
s.rto *= 2
+ // Cap RTO to remaining time.
+ if s.rto > remaining {
+ s.rto = remaining
+ }
+
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
//
// Retransmit timeouts:
@@ -1168,6 +1202,8 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// RFC 6298 Rule 5.3
if s.sndUna == s.sndNxt {
s.outstanding = 0
+ // Reset firstRetransmittedSegXmitTime to the zero value.
+ s.firstRetransmittedSegXmitTime = time.Time{}
s.resendTimer.disable()
}
}
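Taken together, the expiry changes above mean the sender gives up either when the exponential backoff reaches MaxRTO or when a user-specified timeout, measured from the original transmit time of the first retransmitted segment, is exhausted; each doubled RTO is additionally clamped so the timer never overshoots the remaining budget. A self-contained sketch of that decision as a free function (illustrative parameters, not the sender's actual fields):

package main

import (
	"fmt"
	"time"
)

const maxRTO = 120 * time.Second

// nextRTO mirrors the expiry logic above: it reports whether the connection
// should keep retrying, and if so returns the doubled RTO clamped to the
// remaining user-timeout budget.
func nextRTO(rto, userTimeout, elapsed time.Duration) (time.Duration, bool) {
	remaining := maxRTO
	if userTimeout != 0 {
		remaining = userTimeout - elapsed // cap to the user-specified timeout
	}
	if remaining <= 0 || rto >= maxRTO {
		return 0, false // give up: the connection has timed out
	}
	rto *= 2
	if rto > remaining {
		rto = remaining // never overshoot the remaining budget
	}
	return rto, true
}

func main() {
	// 10s user timeout, 9s elapsed since the first retransmission: the
	// doubled RTO (4s) is clamped to the 1s that remains.
	fmt.Println(nextRTO(2*time.Second, 10*time.Second, 9*time.Second)) // 1s true
	// Budget exhausted: time the connection out.
	fmt.Println(nextRTO(2*time.Second, 10*time.Second, 11*time.Second)) // 0s false
}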
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
index 12eff8afc..8b20c3455 100644
--- a/pkg/tcpip/transport/tcp/snd_state.go
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -48,3 +48,13 @@ func (s *sender) loadRttMeasureTime(unix unixTime) {
func (s *sender) afterLoad() {
s.resendTimer.init(&s.resendWaker)
}
+
+// saveFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime {
+ return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()}
+}
+
+// loadFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) {
+ s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 6d022a266..e8fe4dab5 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -75,6 +75,20 @@ func TestGiveUpConnect(t *testing.T) {
if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted)
}
+
+ // Call Connect again to retrieve the handshake failure status
+ // and stats updates.
+ if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
+ t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrAborted)
+ }
+
+ if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = 1", got)
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
}
func TestConnectIncrementActiveConnection(t *testing.T) {
@@ -206,17 +220,18 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ // Create a listening endpoint.
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send a SYN request.
@@ -256,7 +271,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -264,6 +279,13 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
}
}
+ // Lower the stack-wide TIME_WAIT timeout so that the reservations
+ // are released instantly on Close.
+ tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
+ t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err)
+ }
+
c.EP.Close()
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
@@ -285,6 +307,11 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
// Get the ACK to the FIN we just sent.
c.GetPacket()
+ // Since an active close was done we need to wait for a little more than
+ // the TIME_WAIT timeout set above for the port reservations to be
+ // released and the socket to move to a CLOSED state.
+ time.Sleep(20 * time.Millisecond)
+
// Now resend the same ACK, this ACK should generate a RST as there
// should be no endpoint in SYN-RCVD state and we are not using
// syn-cookies yet. The reason we send the same ACK is we need a valid
@@ -296,8 +323,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
}
func TestTCPResetsReceivedIncrement(t *testing.T) {
@@ -376,6 +403,13 @@ func TestConnectResetAfterClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ // Set TCPLingerTimeout to 3 seconds so that sockets are marked closed
+ // after 3 seconds in the FIN_WAIT2 state.
+ tcpLingerTimeout := 3 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err)
+ }
+
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -396,12 +430,24 @@ func TestConnectResetAfterClose(t *testing.T) {
DstPort: c.Port,
Flags: header.TCPFlagAck,
SeqNum: 790,
- AckNum: c.IRS.Add(1),
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Wait for the ep to give up waiting for a FIN.
+ time.Sleep(tcpLingerTimeout + 1*time.Second)
+
+ // Now send an ACK and it should trigger a RST as the endpoint should
+ // not exist anymore.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
- // Wait for the ep to give up waiting for a FIN, and send a RST.
- time.Sleep(3 * time.Second)
for {
b := c.GetPacket()
tcpHdr := header.TCP(header.IPv4(b).Payload())
@@ -413,15 +459,128 @@ func TestConnectResetAfterClose(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst),
),
)
break
}
}
+// TestClosingWithEnqueuedSegments tests handling of still enqueued segments
+// when the endpoint transitions to StateClose. The in-flight segments are
+// re-enqueued to any matching listening endpoint.
+func TestClosingWithEnqueuedSegments(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ // Send a FIN for ESTABLISHED --> CLOSE_WAIT
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Get the ACK for the FIN we sent.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ // Close the application endpoint for CLOSE_WAIT --> LAST_ACK
+ ep.Close()
+
+ // Get the FIN
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ // Pause the endpoint's protocolMainLoop.
+ ep.(interface{ StopWork() }).StopWork()
+
+ // Enqueue the last ACK followed by a FIN/ACK matching the endpoint.
+ //
+ // Send the last ACK for LAST_ACK --> CLOSED.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 791,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Send a FIN/ACK packet. This generates a RST when SYN
+ // cookies are not in use, as in this test.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 792,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Unpause the endpoint's protocolMainLoop.
+ ep.(interface{ ResumeWork() }).ResumeWork()
+
+ // Wait for the protocolMainLoop to resume and update state.
+ time.Sleep(10 * time.Millisecond)
+
+ // Expect the endpoint to be closed.
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = 1", got)
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
+
+ // Check that the endpoint was moved to CLOSED and that netstack sent a
+ // reset in response to the ACK packet that we sent after the last ACK.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+}
+
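The StopWork/ResumeWork calls in the test above use Go's anonymous-interface type assertion to reach test-only hooks without naming the endpoint's concrete type. A minimal sketch of the pattern, with illustrative names:

package main

import "fmt"

// worker exposes test-only hooks, analogous to the endpoint's
// StopWork/ResumeWork.
type worker struct{ paused bool }

func (w *worker) StopWork()   { w.paused = true }
func (w *worker) ResumeWork() { w.paused = false }

func main() {
	// The caller holds only an opaque interface value...
	var ep interface{} = &worker{}

	// ...but can still reach the hooks by asserting to an anonymous
	// interface listing just the methods it needs.
	ep.(interface{ StopWork() }).StopWork()
	fmt.Println(ep.(*worker).paused) // true
	ep.(interface{ ResumeWork() }).ResumeWork()
	fmt.Println(ep.(*worker).paused) // false
}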
func TestSimpleReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -474,6 +633,352 @@ func TestSimpleReceive(t *testing.T) {
)
}
+// TestUserSuppliedMSSOnConnectV4 tests that the user-supplied MSS is used when
+// creating a new active IPv4 TCP socket. It should be present in the sent TCP
+// SYN segment.
+func TestUserSuppliedMSSOnConnectV4(t *testing.T) {
+ const mtu = 5000
+ const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ "EqualToMaxMSS",
+ maxMSS,
+ maxMSS,
+ },
+ {
+ "LessThanMTU",
+ maxMSS - 1,
+ maxMSS - 1,
+ },
+ {
+ "GreaterThanMTU",
+ maxMSS + 1,
+ maxMSS,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ // Set the MSS socket option.
+ opt := tcpip.MaxSegOption(test.setMSS)
+ if err := c.EP.SetSockOpt(opt); err != nil {
+ t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
+ }
+
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+
+ // Start connection attempt to IPv4 address.
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive the SYN packet with our user-supplied MSS.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ })
+ }
+}
+
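As a concrete check of the maxMSS arithmetic used above: with mtu = 5000, the minimum IPv4 header is 20 bytes and the minimum TCP header is 20 bytes, so maxMSS = 5000 - 20 - 20 = 4960. The IPv6 variant below subtracts the fixed 40-byte IPv6 header instead, giving maxMSS = 5000 - 40 - 20 = 4940.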
+// TestUserSuppliedMSSOnConnectV6 tests that the user-supplied MSS is used when
+// creating a new active IPv6 TCP socket. It should be present in the sent TCP
+// SYN segment.
+func TestUserSuppliedMSSOnConnectV6(t *testing.T) {
+ const mtu = 5000
+ const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ "EqualToMaxMSS",
+ maxMSS,
+ maxMSS,
+ },
+ {
+ "LessThanMTU",
+ maxMSS - 1,
+ maxMSS - 1,
+ },
+ {
+ "GreaterThanMTU",
+ maxMSS + 1,
+ maxMSS,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ // Set the MSS socket option.
+ opt := tcpip.MaxSegOption(test.setMSS)
+ if err := c.EP.SetSockOpt(opt); err != nil {
+ t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err)
+ }
+
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+
+ // Start connection attempt to IPv6 address.
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive the SYN packet with our user-supplied MSS.
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ })
+ }
+}
+
+func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
+// peers can send data and expect a response within a reasonable amount of time
+// without calling Accept on the listening endpoint first.
+//
+// This test uses IPv4.
+func TestTCPAckBeforeAcceptV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ // Send data before accepting the connection.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
+// peers can send data and expect a response within a reasonable amount of time
+// without calling Accept on the listening endpoint first.
+//
+// This test uses IPv6.
+func TestTCPAckBeforeAcceptV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ // Send data before accepting the connection.
+ c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestSendRstOnListenerRxAckV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+func TestSendRstOnListenerRxAckV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true /* v6Only */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+// TestListenShutdown tests that a listening endpoint does not process
+// incoming segments after it has been shut down for read.
+func TestListenShutdown(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatal("Shutdown failed:", err)
+ }
+
+ // Wait for the endpoint state to be propagated.
+ time.Sleep(10 * time.Millisecond)
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ c.CheckNoPacket("Packet received when listening socket was shutdown")
+}
+
func TestTOSV4(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -635,6 +1140,71 @@ func TestConnectBindToDevice(t *testing.T) {
}
}
+func TestRstOnSynSent(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create an endpoint but don't complete the handshake, because we
+ // want to interfere with the handshake process.
+ c.Create(-1)
+
+ // Start connection attempt.
+ waitEntry, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+ if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, tcpip.ErrConnectStarted)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+
+ // Ensure that we've reached SynSent state
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ // Send a packet with a proper ACK and a RST flag to move the socket
+ // into the error state and close it out.
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagRst | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for packet to arrive")
+ }
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrConnectionRefused)
+ }
+
+ // Due to the RST the endpoint should be in an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+}
+
func TestOutOfOrderReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -930,8 +1500,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- // We shouldn't consume a sequence number on RST.
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.SeqNum(uint32(c.IRS)+2),
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
@@ -1623,7 +2192,7 @@ func TestDelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOpt(tcpip.DelayOption(1))
+ c.EP.SetSockOptInt(tcpip.DelayOption, 1)
var allData []byte
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
@@ -1671,7 +2240,7 @@ func TestUndelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOpt(tcpip.DelayOption(1))
+ c.EP.SetSockOptInt(tcpip.DelayOption, 1)
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
@@ -1704,7 +2273,7 @@ func TestUndelay(t *testing.T) {
// Check that we don't get the second packet yet.
c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
- c.EP.SetSockOpt(tcpip.DelayOption(0))
+ c.EP.SetSockOptInt(tcpip.DelayOption, 0)
// Check that data is received.
second := c.GetPacket()
@@ -1741,7 +2310,7 @@ func TestMSSNotDelayed(t *testing.T) {
fn func(tcpip.Endpoint)
}{
{"no-op", func(tcpip.Endpoint) {}},
- {"delay", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.DelayOption(1)) }},
+ {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }},
{"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }},
}
@@ -2211,6 +2780,13 @@ loop:
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
+
+ if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
+ t.Errorf("got stats.TCP.EstablishedResets.Value() = %v, want = 1", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
}
func TestSendOnResetConnection(t *testing.T) {
@@ -2905,6 +3481,13 @@ func TestReadAfterClosedState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ // Set TCPTimeWaitTimeout to 1 second so that sockets are marked closed
+ // after 1 second in TIME_WAIT state.
+ tcpTimeWaitTimeout := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
@@ -2912,12 +3495,12 @@ func TestReadAfterClosedState(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock)
}
// Shutdown immediately for write, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -2955,10 +3538,9 @@ func TestReadAfterClosedState(t *testing.T) {
),
)
- // Give the stack the chance to transition to closed state. Note that since
- // both the sender and receiver are now closed, we effectively skip the
- // TIME-WAIT state.
- time.Sleep(1 * time.Second)
+ // Give the stack the chance to transition to closed state from
+ // TIME_WAIT.
+ time.Sleep(tcpTimeWaitTimeout * 2)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
@@ -2975,7 +3557,7 @@ func TestReadAfterClosedState(t *testing.T) {
peekBuf := make([]byte, 10)
n, _, err := c.EP.Peek([][]byte{peekBuf})
if err != nil {
- t.Fatalf("Peek failed: %v", err)
+ t.Fatalf("Peek failed: %s", err)
}
peekBuf = peekBuf[:n]
@@ -2986,7 +3568,7 @@ func TestReadAfterClosedState(t *testing.T) {
// Receive data.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -2996,11 +3578,11 @@ func TestReadAfterClosedState(t *testing.T) {
// Now that we drained the queue, check that functions fail with the
// right error code.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive)
}
if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive)
}
}
@@ -3773,8 +4355,9 @@ func TestKeepalive(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ const keepAliveInterval = 10 * time.Millisecond
c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
- c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond))
+ c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5))
c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
@@ -3864,19 +4447,43 @@ func TestKeepalive(t *testing.T) {
)
}
+ // Sleep for a little over the keepalive interval to make sure
+ // the timer has time to fire after the last ACK and close the
+ // socket.
+ time.Sleep(keepAliveInterval + 5*time.Millisecond)
+
// The connection should be terminated after 5 unacked keepalives.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(790)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
),
)
+ if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %v, want = 1", got)
+ }
+
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
}
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
}
func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
@@ -3890,7 +4497,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
RcvWnd: 30000,
})
- // Receive the SYN-ACK reply.w
+ // Receive the SYN-ACK reply.
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
iss = seqnum.Value(tcp.SequenceNumber())
@@ -3923,6 +4530,50 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
return irs, iss
}
+func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+ // Send a SYN request.
+ irs = seqnum.Value(789)
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetV6Packet()
+ tcp := header.TCP(header.IPv6(b).Payload())
+ iss = seqnum.Value(tcp.SequenceNumber())
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(srcPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(irs) + 1),
+ }
+
+ if synCookieInUse {
+ // When cookies are in use window scaling is disabled.
+ tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
+ WS: -1,
+ MSS: c.MSSWithoutOptionsV6(),
+ }))
+ }
+
+ checker.IPv6(t, b, checker.TCP(tcpCheckers...))
+
+ // Send ACK.
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+ return irs, iss
+}
+
// TestListenBacklogFull tests that netstack does not complete handshakes if the
// listen backlog for the endpoint is full.
func TestListenBacklogFull(t *testing.T) {
@@ -4025,6 +4676,210 @@ func TestListenBacklogFull(t *testing.T) {
}
}
+// TestListenNoAcceptNonUnicastV4 makes sure that TCP segments with a
+// non-unicast IPv4 address are not accepted.
+func TestListenNoAcceptNonUnicastV4(t *testing.T) {
+ multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
+ otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+
+ tests := []struct {
+ name string
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ }{
+ {
+ "SourceUnspecified",
+ header.IPv4Any,
+ context.StackAddr,
+ },
+ {
+ "SourceBroadcast",
+ header.IPv4Broadcast,
+ context.StackAddr,
+ },
+ {
+ "SourceOurMulticast",
+ multicastAddr,
+ context.StackAddr,
+ },
+ {
+ "SourceOtherMulticast",
+ otherMulticastAddr,
+ context.StackAddr,
+ },
+ {
+ "DestUnspecified",
+ context.TestAddr,
+ header.IPv4Any,
+ },
+ {
+ "DestBroadcast",
+ context.TestAddr,
+ header.IPv4Broadcast,
+ },
+ {
+ "DestOurMulticast",
+ context.TestAddr,
+ multicastAddr,
+ },
+ {
+ "DestOtherMulticast",
+ context.TestAddr,
+ otherMulticastAddr,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup failed: %s", err)
+ }
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := c.EP.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ irs := seqnum.Value(789)
+ c.SendPacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, test.srcAddr, test.dstAddr)
+ c.CheckNoPacket("Should not have received a response")
+
+ // Handle normal packet.
+ c.SendPacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, context.TestAddr, context.StackAddr)
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+ })
+ }
+}
+
+// TestListenNoAcceptNonUnicastV6 makes sure that TCP segments with a
+// non-unicast IPv6 address are not accepted.
+func TestListenNoAcceptNonUnicastV6(t *testing.T) {
+ multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
+ otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
+
+ tests := []struct {
+ name string
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ }{
+ {
+ "SourceUnspecified",
+ header.IPv6Any,
+ context.StackV6Addr,
+ },
+ {
+ "SourceAllNodes",
+ header.IPv6AllNodesMulticastAddress,
+ context.StackV6Addr,
+ },
+ {
+ "SourceOurMulticast",
+ multicastAddr,
+ context.StackV6Addr,
+ },
+ {
+ "SourceOtherMulticast",
+ otherMulticastAddr,
+ context.StackV6Addr,
+ },
+ {
+ "DestUnspecified",
+ context.TestV6Addr,
+ header.IPv6Any,
+ },
+ {
+ "DestAllNodes",
+ context.TestV6Addr,
+ header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ "DestOurMulticast",
+ context.TestV6Addr,
+ multicastAddr,
+ },
+ {
+ "DestOtherMulticast",
+ context.TestV6Addr,
+ otherMulticastAddr,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup failed: %s", err)
+ }
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := c.EP.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ irs := seqnum.Value(789)
+ c.SendV6PacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, test.srcAddr, test.dstAddr)
+ c.CheckNoPacket("Should not have received a response")
+
+ // Handle normal packet.
+ c.SendV6PacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, context.TestV6Addr, context.StackV6Addr)
+ checker.IPv6(t, c.GetV6Packet(),
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+ })
+ }
+}
+
func TestListenSynRcvdQueueFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -4056,7 +4911,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
SrcPort: context.TestPort,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(789),
+ SeqNum: irs,
RcvWnd: 30000,
})
@@ -4167,7 +5022,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
// Send a SYN request.
irs := seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
+ // Pick a different source port for the new SYN.
+ SrcPort: context.TestPort + 1,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: irs,
@@ -4207,6 +5063,125 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
}
+func TestSynRcvdBadSeqNumber(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Start listening.
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state
+ irs := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcpHdr.SequenceNumber())
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(irs) + 1),
+ }
+ checker.IPv4(t, b, checker.TCP(tcpCheckers...))
+
+ // Now send a packet with an out-of-window sequence number
+ largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: largeSeqnum,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+
+ // Should receive an ACK with the expected SEQ number
+ b = c.GetPacket()
+ tcpCheckers = []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.AckNum(uint32(irs) + 1),
+ checker.SeqNum(uint32(iss + 1)),
+ }
+ checker.IPv4(t, b, checker.TCP(tcpCheckers...))
+
+ // Now that the socket replied appropriately with the ACK,
+ // complete the connection to test that the large SEQ num
+ // did not change the state from SYN-RCVD.
+
+ // Send ACK to move to ESTABLISHED state.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+
+ newEP, _, err := c.EP.Accept()
+
+ if err != nil && err != tcpip.ErrWouldBlock {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ if err == tcpip.ErrWouldBlock {
+ // Try to accept the connections in the backlog.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ newEP, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Now verify that the TCP socket is usable and in a connected state.
+ data := "Don't panic"
+ _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ pkt := c.GetPacket()
+ tcpHdr = header.TCP(header.IPv4(pkt).Payload())
+ if string(tcpHdr.Payload()) != data {
+ t.Fatalf("Unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
+ }
+}
+
func TestPassiveConnectionAttemptIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -4381,6 +5356,9 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
+ if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected {
+ t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected)
+ }
// Listening endpoint remains in listen state.
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
@@ -4668,3 +5646,918 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
payloadSize *= 2
}
}
+
+func TestDelayEnabled(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ checkDelayOption(t, c, false, 0) // Delay is disabled by default.
+
+ for _, v := range []struct {
+ delayEnabled tcp.DelayEnabled
+ wantDelayOption int
+ }{
+ {delayEnabled: false, wantDelayOption: 0},
+ {delayEnabled: true, wantDelayOption: 1},
+ } {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil {
+ t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err)
+ }
+ checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
+ }
+}
+
+func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) {
+ t.Helper()
+
+ var gotDelayEnabled tcp.DelayEnabled
+ if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
+ t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err)
+ }
+ if gotDelayEnabled != wantDelayEnabled {
+ t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled)
+ }
+
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue))
+ if err != nil {
+ t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err)
+ }
+ gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption)
+ if err != nil {
+ t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err)
+ }
+ if gotDelayOption != wantDelayOption {
+ t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption)
+ }
+}
+
+func TestTCPLingerTimeout(t *testing.T) {
+ c := context.New(t, 1500 /* mtu */)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ testCases := []struct {
+ name string
+ tcpLingerTimeout time.Duration
+ want time.Duration
+ }{
+ {"NegativeLingerTimeout", -123123, 0},
+ {"ZeroLingerTimeout", 0, 0},
+ {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
+ // Values > stack's TCPLingerTimeout are capped to the stack's
+ // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
+ {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil {
+ t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err)
+ }
+ var v tcpip.TCPLingerTimeoutOption
+ if err := c.EP.GetSockOpt(&v); err != nil {
+ t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err)
+ }
+ if got, want := time.Duration(v), tc.want; got != want {
+ t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want)
+ }
+ })
+ }
+}
+
+func TestTCPTimeWaitRSTIgnored(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Now send a RST and this should be ignored and not
+ // generate an ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagRst,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ })
+
+ c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second)
+
+ // Out of order ACK should generate an immediate ACK in
+ // TIME_WAIT.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 3,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+}
+
+func TestTCPTimeWaitOutOfOrder(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Out of order ACK should generate an immediate ACK in
+ // TIME_WAIT.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 3,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+}
+
+func TestTCPTimeWaitNewSyn(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Send a SYN request w/ sequence number lower than
+ // the highest sequence number sent. We just reuse
+ // the same number.
+ iss = seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
+
+ // Send a SYN request w/ sequence number higher than
+ // the highest sequence number sent.
+ iss = seqnum.Value(792)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b = c.GetPacket()
+ tcpHdr = header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+}
+
+func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
+ // after 5 seconds in TIME_WAIT state.
+ tcpTimeWaitTimeout := 5 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ time.Sleep(2 * time.Second)
+
+ // Now send a duplicate FIN. This should cause the TIME_WAIT to extend
+ // by another 5 seconds and also send us a duplicate ACK, since the
+ // duplicate FIN indicates that the final ACK was potentially lost.
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Sleep for 4 seconds so that at this point we are 1 second past the
+ // original tcpTimeWaitTimeout of 5 seconds.
+ time.Sleep(4 * time.Second)
+
+ // Send an ACK and it should not generate any packet as the socket
+ // should still be in TIME_WAIT for another 5 seconds due to the
+ // duplicate FIN we sent earlier.
+ *ackHeaders = *finHeaders
+ ackHeaders.SeqNum = ackHeaders.SeqNum + 1
+ ackHeaders.Flags = header.TCPFlagAck
+ c.SendPacket(nil, ackHeaders)
+
+ c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second)
+ // Now sleep for another 2 seconds so that we are past the
+ // extended TIME_WAIT of 7 seconds (2 + 5).
+ time.Sleep(2 * time.Second)
+
+ // Resend the same ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Receive the RST that should be generated as there is no valid
+ // endpoint.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(ackHeaders.AckNum)),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
+
+ if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %v, want = %v", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
+ }
+}
+
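
The test above relies on the rule that a retransmitted FIN in TIME_WAIT
restarts the wait timer and is re-ACKed, since the retransmission implies the
final ACK was lost. A minimal sketch of that timer-reset behavior, using
hypothetical names (timeWaitState, onDuplicateFIN) rather than gvisor's
internals:

package main

import (
	"fmt"
	"time"
)

// timeWaitState is an illustrative stand-in for an endpoint in TIME_WAIT.
type timeWaitState struct {
	timer   *time.Timer
	timeout time.Duration
}

// onDuplicateFIN re-ACKs a retransmitted FIN and restarts the TIME_WAIT
// timer, extending the wait by a full timeout from now.
func (s *timeWaitState) onDuplicateFIN() {
	fmt.Println("re-sending the final ACK")
	s.timer.Reset(s.timeout)
}

func main() {
	s := &timeWaitState{timeout: 5 * time.Second}
	s.timer = time.AfterFunc(s.timeout, func() { fmt.Println("TIME_WAIT expired") })
	time.Sleep(2 * time.Second)
	s.onDuplicateFIN() // expiry moves from t=5s to t=7s, as in the test above
	time.Sleep(6 * time.Second)
}
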
+func TestTCPCloseWithData(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
+ // after 5 seconds in TIME_WAIT state.
+ tcpTimeWaitTimeout := 5 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ RcvWnd: 30000,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Now trigger a passive close by sending a FIN.
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ RcvWnd: 30000,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Now write a few bytes and then close the endpoint.
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
+ b = c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+2), // AckNum is iss + 2 as our FIN has also been ACKed.
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Errorf("got data = %x, want = %x", p, data)
+ }
+
+ c.EP.Close()
+ // Check the FIN.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
+ checker.AckNum(uint32(iss+2)),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ // First send a partial ACK.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)-1),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Now send a full ACK.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Now ACK the FIN.
+ ackHeaders.AckNum++
+ c.SendPacket(nil, ackHeaders)
+
+ // Now send an ACK and we should get a RST back as the endpoint should
+ // be in CLOSED state.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Check the RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(ackHeaders.AckNum)),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
+}
+
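
The partial/full ACK distinction above is plain sequence arithmetic: the SYN
consumes one sequence number, each data byte one more, and the FIN one more
again. A small sketch of the numbers the test checks (the IRS value is made up
for illustration):

package main

import "fmt"

func main() {
	irs := uint32(1000) // stack's initial sequence number (illustrative)
	dataLen := uint32(3)

	fmt.Println("first data byte:", irs+1)             // the SYN consumed irs
	fmt.Println("FIN sequence number:", irs+1+dataLen) // sent after the data
	fmt.Println("partial ACK of data:", irs+1+dataLen-1)
	fmt.Println("full ACK of data:", irs+1+dataLen)
	fmt.Println("ACK covering the FIN:", irs+1+dataLen+1)
}
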
+func TestTCPUserTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+
+ userTimeout := 50 * time.Millisecond
+ c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+
+ // Send some data and wait before ACKing it.
+ view := buffer.NewView(3)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ next := uint32(c.IRS) + 1
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ // Wait for a little over the minimum retransmit timeout of 200ms for
+ // the retransmitTimer to fire and close the connection.
+ time.Sleep(tcp.MinRTO + 10*time.Millisecond)
+
+ // No packet should be received as the connection should be silently
+ // closed due to timeout.
+ c.CheckNoPacket("unexpected packet received after userTimeout has expired")
+
+ next += uint32(len(view))
+
+ // The connection should be terminated after userTimeout has expired.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(next)),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ }
+
+ if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ }
+}
+
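
TestTCPUserTimeout depends on the endpoint silently aborting once written data
has been unacknowledged for longer than the user timeout; only the later
incoming ACK draws a RST from the dead endpoint. A minimal sketch of that
abort condition, under an assumed name (shouldAbort is not gvisor's):

package main

import (
	"fmt"
	"time"
)

// shouldAbort reports whether a retransmission timer should give up and
// abort the connection instead of retransmitting again.
func shouldAbort(firstUnackedSentAt time.Time, userTimeout time.Duration) bool {
	return userTimeout != 0 && time.Since(firstUnackedSentAt) >= userTimeout
}

func main() {
	sentAt := time.Now().Add(-60 * time.Millisecond)
	fmt.Println(shouldAbort(sentAt, 50*time.Millisecond)) // true: close silently
}
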
+func TestKeepaliveWithUserTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+
+ const keepAliveInterval = 10 * time.Millisecond
+ c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
+ c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
+ c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10))
+ c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
+
+ // Set userTimeout to be the duration for 3 keepalive probes.
+ userTimeout := 30 * time.Millisecond
+ c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+
+ // Check that the connection is still alive.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ }
+
+ // Now receive 2 keepalives, but don't ACK them. With the options set
+ // above, keepalive probes are sent 10ms apart once the connection has
+ // been idle for 10ms, so the 30ms userTimeout expires exactly when the
+ // 3rd probe would be sent and the connection should be reset.
+ for i := 0; i < 2; i++ {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)),
+ checker.AckNum(uint32(790)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ // Sleep for a little over the keepalive interval to make sure the
+ // timer has time to fire after the last ACK and close the socket.
+ time.Sleep(keepAliveInterval + 5*time.Millisecond)
+
+ // The connection should be terminated after 30ms.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(c.IRS + 1),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ }
+ if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
+ }
+}
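
The timing in TestKeepaliveWithUserTimeout is a fixed budget: with the options
set above, the user timeout is exhausted exactly when the third probe is due.
A worked version of that arithmetic (values copied from the test):

package main

import (
	"fmt"
	"time"
)

func main() {
	idle := 10 * time.Millisecond        // KeepaliveIdleOption
	interval := 10 * time.Millisecond    // KeepaliveIntervalOption
	userTimeout := 30 * time.Millisecond // TCPUserTimeoutOption

	// Probes that fit before the user timeout expires.
	probesBeforeReset := (userTimeout - idle) / interval
	fmt.Println(probesBeforeReset) // 2: the connection resets at the 3rd
}
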
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD
index 19b0d31c5..b33ec2087 100644
--- a/pkg/tcpip/transport/tcp/testing/context/BUILD
+++ b/pkg/tcpip/transport/tcp/testing/context/BUILD
@@ -8,7 +8,7 @@ go_library(
srcs = ["context.go"],
importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context",
visibility = [
- "//:sandbox",
+ "//visibility:public",
],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index ef823e4ae..b0a376eba 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -231,14 +231,15 @@ func (c *Context) CheckNoPacket(errMsg string) {
// addresses. It will fail with an error if no packet is received for
// 2 seconds.
func (c *Context) GetPacket() []byte {
+ c.t.Helper()
select {
case p := <-c.linkEP.C:
if p.Proto != ipv4.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
+
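+ // The full slice expression hdr[:len(hdr):len(hdr)] below caps the
+ // capacity at the length, so append allocates a fresh buffer instead of
+ // writing into the header's backing array.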
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
@@ -259,14 +260,15 @@ func (c *Context) GetPacket() []byte {
// and destination address. If no packet is available it will return
// nil immediately.
func (c *Context) GetPacketNonBlocking() []byte {
+ c.t.Helper()
select {
case p := <-c.linkEP.C:
if p.Proto != ipv4.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
return b
@@ -302,11 +304,19 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
- c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
}
// BuildSegment builds a TCP segment based on the given Headers and payload.
func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
+ return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr)
+}
+
+// BuildSegmentWithAddrs builds a TCP segment based on the given Headers,
+// payload and source and destination IPv4 addresses.
+func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -319,8 +329,8 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(tcp.ProtocolNumber),
- SrcAddr: TestAddr,
- DstAddr: StackAddr,
+ SrcAddr: src,
+ DstAddr: dst,
})
ip.SetChecksum(^ip.CalculateChecksum())
@@ -337,7 +347,7 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
})
// Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestAddr, StackAddr, uint16(len(t)))
+ xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
// Calculate the TCP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -350,13 +360,26 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
func (c *Context) SendSegment(s buffer.VectorisedView) {
- c.linkEP.Inject(ipv4.ProtocolNumber, s)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: s,
+ })
}
// SendPacket builds and sends a TCP segment (with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
- c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h))
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: c.BuildSegment(payload, h),
+ })
+}
+
+// SendPacketWithAddrs builds and sends a TCP segment (with the provided
+// payload & TCP headers) in an IPv4 packet via the link layer endpoint using
+// the provided source and destination IPv4 addresses.
+func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
+ })
}
// SendAck sends an ACK packet.
@@ -462,14 +485,15 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
// GetV6Packet reads a single packet from the link layer endpoint of the context
// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
func (c *Context) GetV6Packet() []byte {
+ c.t.Helper()
select {
case p := <-c.linkEP.C:
if p.Proto != ipv6.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
+ b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
+ copy(b, p.Pkt.Header.View())
+ copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
return b
@@ -484,6 +508,13 @@ func (c *Context) GetV6Packet() []byte {
// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
// the context.
func (c *Context) SendV6Packet(payload []byte, h *Headers) {
+ c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
+}
+
+// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer
+// endpoint of the context using the provided source and destination IPv6
+// addresses.
+func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -494,8 +525,8 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
NextHeader: uint8(tcp.ProtocolNumber),
HopLimit: 65,
- SrcAddr: TestV6Addr,
- DstAddr: StackV6Addr,
+ SrcAddr: src,
+ DstAddr: dst,
})
// Initialize the TCP header.
@@ -511,14 +542,16 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
})
// Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestV6Addr, StackV6Addr, uint16(len(t)))
+ xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
// Calculate the TCP checksum and set it.
xsum = header.Checksum(payload, xsum)
t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
}
// CreateConnected creates a connected TCP endpoint.
@@ -1059,3 +1092,9 @@ func (c *Context) SetGSOEnabled(enable bool) {
func (c *Context) MSSWithoutOptions() uint16 {
return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
}
+
+// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no
+// options are in use for IPv6 packets.
+func (c *Context) MSSWithoutOptionsV6() uint16 {
+ return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize)
+}
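
MSSWithoutOptions and MSSWithoutOptionsV6 are the same computation with
different network-header sizes. A worked example of the arithmetic, assuming
the common 1500-byte MTU and minimum header sizes:

package main

import "fmt"

func main() {
	const mtu = 1500
	const ipv4Hdr, ipv6Hdr, tcpHdr = 20, 40, 20 // minimum header sizes in bytes

	fmt.Println("IPv4 MSS:", mtu-ipv4Hdr-tcpHdr) // 1460
	fmt.Println("IPv6 MSS:", mtu-ipv6Hdr-tcpHdr) // 1440
}
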
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index c9460aa0d..97e4d5825 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -34,6 +34,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/waiter",
@@ -59,11 +60,3 @@ go_test(
"//pkg/waiter",
],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "udp_packet_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 6e87245b7..1ac4705af 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -31,9 +32,6 @@ type udpPacket struct {
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
- // views is used as buffer for data when its length is large
- // enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
}
// EndpointState represents the state of a UDP endpoint.
@@ -80,6 +78,7 @@ type endpoint struct {
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
+ uniqueID uint64
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -106,6 +105,11 @@ type endpoint struct {
bindToDevice tcpip.NICID
broadcast bool
+ // Values used to reserve a port or register a transport endpoint
+ // (whichever happens first).
+ boundBindToDevice tcpip.NICID
+ boundPortFlags ports.Flags
+
// sendTOS represents IPv4 TOS or IPv6 TrafficClass,
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
@@ -140,7 +144,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
- TransProto: header.TCPProtocolNumber,
+ TransProto: header.UDPProtocolNumber,
},
waiterQueue: waiterQueue,
// RFC 1075 section 5.4 recommends a TTL of 1 for membership
@@ -160,9 +164,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
state: StateInitial,
+ uniqueID: s.UniqueID(),
}
}
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
@@ -171,8 +181,10 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
}
for _, mem := range e.multicastMemberships {
@@ -278,7 +290,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
+func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
localAddr := e.ID.LocalAddress
if isBroadcastOrMulticast(localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
@@ -286,20 +298,20 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr
}
if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
- if nicid == 0 {
- nicid = e.multicastNICID
+ if nicID == 0 {
+ nicID = e.multicastNICID
}
- if localAddr == "" && nicid == 0 {
+ if localAddr == "" && nicID == 0 {
localAddr = e.multicastAddr
}
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop)
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop)
if err != nil {
return stack.Route{}, 0, err
}
- return r, nicid, nil
+ return r, nicID, nil
}
// Write writes data to the endpoint's peer. This method does not block
@@ -378,13 +390,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
} else {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
- nicid := to.NIC
+ nicID := to.NIC
if e.BindNICID != 0 {
- if nicid != 0 && nicid != e.BindNICID {
+ if nicID != 0 && nicID != e.BindNICID {
return 0, nil, tcpip.ErrNoRoute
}
- nicid = e.BindNICID
+ nicID = e.BindNICID
}
if to.Addr == header.IPv4Broadcast && !e.broadcast {
@@ -396,7 +408,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, err
}
- r, _, err := e.connectRoute(nicid, *to, netProto)
+ r, _, err := e.connectRoute(nicID, *to, netProto)
if err != nil {
return 0, nil, err
}
@@ -618,9 +630,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.bindToDevice = 0
return nil
}
- for nicid, nic := range e.stack.NICInfo() {
+ for nicID, nic := range e.stack.NICInfo() {
if nic.Name == string(v) {
- e.bindToDevice = nicid
+ e.bindToDevice = nicID
return nil
}
}
@@ -728,6 +740,10 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = tcpip.MulticastLoopOption(v)
return nil
+ case *tcpip.ReuseAddressOption:
+ *o = 0
+ return nil
+
case *tcpip.ReusePortOption:
e.mu.RLock()
v := e.reusePort
@@ -809,7 +825,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
if useDefaultTTL {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: data,
+ TransportHeader: buffer.View(udp),
+ }); err != nil {
r.Stats().UDP.PacketSendErrors.Increment()
return err
}
@@ -859,7 +879,10 @@ func (e *endpoint) Disconnect() *tcpip.Error {
if e.state != StateConnected {
return nil
}
- id := stack.TransportEndpointID{}
+ var (
+ id stack.TransportEndpointID
+ btd tcpip.NICID
+ )
// Exclude ephemerally bound endpoints.
if e.BindNICID != 0 || e.ID.LocalAddress == "" {
var err *tcpip.Error
@@ -867,7 +890,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
LocalPort: e.ID.LocalPort,
LocalAddress: e.ID.LocalAddress,
}
- id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
+ id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
return err
}
@@ -875,13 +898,15 @@ func (e *endpoint) Disconnect() *tcpip.Error {
} else {
if e.ID.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice)
+ e.boundPortFlags = ports.Flags{}
}
e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.ID = id
+ e.boundBindToDevice = btd
e.route.Release()
e.route = stack.Route{}
e.dstPort = 0
@@ -903,7 +928,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicid := addr.NIC
+ nicID := addr.NIC
var localPort uint16
switch e.state {
case StateInitial:
@@ -913,16 +938,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
break
}
- if nicid != 0 && nicid != e.BindNICID {
+ if nicID != 0 && nicID != e.BindNICID {
return tcpip.ErrInvalidEndpointState
}
- nicid = e.BindNICID
+ nicID = e.BindNICID
default:
return tcpip.ErrInvalidEndpointState
}
- r, nicid, err := e.connectRoute(nicid, addr, netProto)
+ r, nicID, err := e.connectRoute(nicID, addr, netProto)
if err != nil {
return err
}
@@ -950,20 +975,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
}
- id, err = e.registerWithStack(nicid, netProtos, id)
+ id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
}
// Remove the old registration.
if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
}
e.ID = id
+ e.boundBindToDevice = btd
e.route = r.Clone()
e.dstPort = addr.Port
- e.RegisterNICID = nicid
+ e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
e.state = StateConnected
@@ -1018,20 +1044,27 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
-func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
+ flags := ports.Flags{
+ LoadBalanced: e.reusePort,
+ // FIXME(b/129164367): Support SO_REUSEADDR.
+ MostRecent: false,
+ }
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice)
if err != nil {
- return id, err
+ return id, e.bindToDevice, err
}
+ e.boundPortFlags = flags
id.LocalPort = port
}
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice)
+ e.boundPortFlags = ports.Flags{}
}
- return id, err
+ return id, e.bindToDevice, err
}
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
@@ -1057,11 +1090,11 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
}
}
- nicid := addr.NIC
+ nicID := addr.NIC
if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) {
// A local unicast address was specified, verify that it's valid.
- nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
- if nicid == 0 {
+ nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ if nicID == 0 {
return tcpip.ErrBadLocalAddress
}
}
@@ -1070,13 +1103,14 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
LocalPort: addr.Port,
LocalAddress: addr.Addr,
}
- id, err = e.registerWithStack(nicid, netProtos, id)
+ id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
}
e.ID = id
- e.RegisterNICID = nicid
+ e.boundBindToDevice = btd
+ e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
@@ -1111,9 +1145,14 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
+ addr := e.ID.LocalAddress
+ if e.state == StateConnected {
+ addr = e.route.LocalAddress
+ }
+
return tcpip.FullAddress{
NIC: e.RegisterNICID,
- Addr: e.ID.LocalAddress,
+ Addr: addr,
Port: e.ID.LocalPort,
}, nil
}
@@ -1154,17 +1193,17 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
// Get the header then trim it from the view.
- hdr := header.UDP(vv.First())
- if int(hdr.Length()) > vv.Size() {
+ hdr := header.UDP(pkt.Data.First())
+ if int(hdr.Length()) > pkt.Data.Size() {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
- vv.TrimFront(header.UDPMinimumSize)
+ pkt.Data.TrimFront(header.UDPMinimumSize)
e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
@@ -1188,18 +1227,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
- pkt := &udpPacket{
+ packet := &udpPacket{
senderAddress: tcpip.FullAddress{
NIC: r.NICID(),
Addr: id.RemoteAddress,
Port: hdr.SourcePort(),
},
}
- pkt.data = vv.Clone(pkt.views[:])
- e.rcvList.PushBack(pkt)
- e.rcvBufSize += vv.Size()
+ packet.data = pkt.Data
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += pkt.Data.Size()
- pkt.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.NowNanoseconds()
e.rcvMu.Unlock()
@@ -1210,7 +1249,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
}
// State implements tcpip.Endpoint.State.
@@ -1234,6 +1273,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+// Wait implements tcpip.Endpoint.Wait.
+func (*endpoint) Wait() {}
+
func isBroadcastOrMulticast(a tcpip.Address) bool {
return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
}
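
The boundPortFlags/boundBindToDevice bookkeeping above ensures ReleasePort is
always called with exactly the values that ReservePort succeeded with, even if
the user changes SO_REUSEPORT or SO_BINDTODEVICE afterwards. A minimal sketch
of that reserve/release pairing, with hypothetical types (flags, registrar)
standing in for the ports package:

package main

import (
	"errors"
	"fmt"
)

type flags struct{ loadBalanced bool }

type registrar struct{ reserved map[int]flags }

func (r *registrar) reserve(port int, f flags) error {
	if _, ok := r.reserved[port]; ok {
		return errors.New("port in use")
	}
	r.reserved[port] = f
	return nil
}

// release must see the same flags that were used to reserve, which is why
// the endpoint records boundPortFlags at registration time.
func (r *registrar) release(port int, f flags) {
	if r.reserved[port] == f {
		delete(r.reserved, port)
	}
}

func main() {
	r := &registrar{reserved: map[int]flags{}}
	f := flags{loadBalanced: true}
	if err := r.reserve(8080, f); err != nil {
		fmt.Println(err)
		return
	}
	r.release(8080, f)           // matched flags: the reservation is dropped
	fmt.Println(len(r.reserved)) // 0
}
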
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index b227e353b..43fb047ed 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -109,7 +109,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
// pass it to the reservation machinery.
id := e.ID
e.ID.LocalPort = 0
- e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
+ e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index d399ec722..fc706ede2 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -16,7 +16,6 @@ package udp
import (
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -44,12 +43,12 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
f.handler(&ForwarderRequest{
stack: f.stack,
route: r,
id: id,
- vv: vv,
+ pkt: pkt,
})
return true
@@ -62,7 +61,7 @@ type ForwarderRequest struct {
stack *stack.Stack
route *stack.Route
id stack.TransportEndpointID
- vv buffer.VectorisedView
+ pkt tcpip.PacketBuffer
}
// ID returns the 4-tuple (src address, src port, dst address, dst port) that
@@ -90,7 +89,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.rcvReady = true
ep.rcvMu.Unlock()
- ep.HandlePacket(r.route, r.id, r.vv)
+ ep.HandlePacket(r.route, r.id, r.pkt)
return ep, nil
}
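
Based on the signatures in this diff, a UDP forwarder is installed by handing
Forwarder.HandlePacket to the stack as the transport protocol handler; each
packet with no matching endpoint then surfaces as a ForwarderRequest. A
minimal usage sketch (error handling and endpoint lifetime are elided):

package main

import (
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
	"gvisor.dev/gvisor/pkg/waiter"
)

// installUDPForwarder accepts every otherwise-unhandled UDP packet by
// creating an endpoint for its flow.
func installUDPForwarder(s *stack.Stack) {
	fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
		var wq waiter.Queue
		ep, err := r.CreateEndpoint(&wq)
		if err != nil {
			return
		}
		// Read from / write to ep as with any tcpip.Endpoint, then Close it.
		_ = ep
	})
	s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)
}

func main() {}
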
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index de026880f..259c3072a 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -66,10 +66,10 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool {
// Get the header then trim it from the view.
- hdr := header.UDP(vv.First())
- if int(hdr.Length()) > vv.Size() {
+ hdr := header.UDP(pkt.Data.First())
+ if int(hdr.Length()) > pkt.Data.Size() {
// Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
return true
@@ -116,13 +116,18 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
available := int(mtu) - headerLen
- payloadLen := len(netHeader) + vv.Size()
+ payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
- payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
- payload.Append(vv)
+ // The buffers used by pkt may be used elsewhere in the system.
+ // For example, a raw or packet socket may use what UDP
+ // considers an unreachable destination. Thus we deep copy pkt
+ // to prevent multiple ownership and save/restore (S/R) errors.
+ newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...)
+ payload := newNetHeader.ToVectorisedView()
+ payload.Append(pkt.Data.ToView().ToVectorisedView())
payload.CapLength(payloadLen)
hdr := buffer.NewPrependable(headerLen)
@@ -130,7 +135,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv4DstUnreachable)
pkt.SetCode(header.ICMPv4PortUnreachable)
pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
- r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
case header.IPv6AddressSize:
if !r.Stack().AllowICMPMessage() {
@@ -151,12 +159,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
available := int(mtu) - headerLen
- payloadLen := len(netHeader) + vv.Size()
+ payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
- payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
- payload.Append(vv)
+ payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader})
+ payload.Append(pkt.Data)
payload.CapLength(payloadLen)
hdr := buffer.NewPrependable(headerLen)
@@ -164,7 +172,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
pkt.SetType(header.ICMPv6DstUnreachable)
pkt.SetCode(header.ICMPv6PortUnreachable)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
- r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ Header: hdr,
+ Data: payload,
+ })
}
return true
}
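
The deep copy above works because appending to a nil slice always allocates:
the ICMP payload gets its own backing array and cannot alias buffers still
owned by a raw or packet socket. A tiny demonstration of the aliasing this
avoids:

package main

import "fmt"

func main() {
	orig := []byte{1, 2, 3}
	copied := append([]byte(nil), orig...) // forces a fresh allocation
	orig[0] = 9                            // mutate the original buffer
	fmt.Println(copied[0])                 // still 1: no shared storage
}
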
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index b724d788c..7051a7a9c 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -356,9 +356,9 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw
if p.Proto != flow.netProto() {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
h := flow.header4Tuple(outgoing)
checkers := append(
@@ -397,7 +397,8 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) {
func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
+ payloadStart := len(buf) - len(payload)
+ copy(buf[payloadStart:], payload)
// Initialize the IP header.
ip := header.IPv6(buf)
@@ -431,7 +432,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ NetworkHeader: buffer.View(ip),
+ TransportHeader: buffer.View(u),
+ })
}
// injectV4Packet creates a V4 test packet with the given payload and header
@@ -441,7 +446,8 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
+ payloadStart := len(buf) - len(payload)
+ copy(buf[payloadStart:], payload)
// Initialize the IP header.
ip := header.IPv4(buf)
@@ -471,7 +477,12 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
+
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ NetworkHeader: buffer.View(ip),
+ TransportHeader: buffer.View(u),
+ })
}
func newPayload() []byte {
@@ -1442,8 +1453,8 @@ func TestV4UnknownDestination(t *testing.T) {
select {
case p := <-c.linkEP.C:
var pkt []byte
- pkt = append(pkt, p.Header...)
- pkt = append(pkt, p.Payload...)
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
}
@@ -1516,8 +1527,8 @@ func TestV6UnknownDestination(t *testing.T) {
select {
case p := <-c.linkEP.C:
var pkt []byte
- pkt = append(pkt, p.Header...)
- pkt = append(pkt, p.Payload...)
+ pkt = append(pkt, p.Pkt.Header.View()...)
+ pkt = append(pkt, p.Pkt.Data.ToView()...)
if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
}
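
These test updates track the link-endpoint API change from Inject(proto, vv)
to InjectInbound(proto, tcpip.PacketBuffer): the raw bytes now travel in the
Data field, with NetworkHeader/TransportHeader filled in when they are already
known. A minimal sketch of the new call shape using the channel endpoint from
these tests:

package main

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
)

// inject delivers raw bytes to the stack as if they arrived on the link.
func inject(e *channel.Endpoint, frame []byte) {
	e.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
		Data: buffer.View(frame).ToVectorisedView(),
	})
}

func main() {}
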
diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD
index 1f7efb064..0427bc41f 100644
--- a/pkg/waiter/BUILD
+++ b/pkg/waiter/BUILD
@@ -34,11 +34,3 @@ go_test(
],
embed = [":waiter"],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "waiter_list.go",
- ],
- visibility = ["//:sandbox"],
-)