summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--README.md5
-rw-r--r--WORKSPACE23
-rw-r--r--g3doc/README.md4
-rw-r--r--g3doc/user_guide/compatibility.md5
-rw-r--r--g3doc/user_guide/install.md9
-rw-r--r--images/benchmarks/node/package-lock.json1025
-rw-r--r--images/benchmarks/node/package.json2
-rw-r--r--nogo.yaml6
-rw-r--r--pkg/abi/linux/file.go4
-rw-r--r--pkg/abi/linux/ioctl.go10
-rw-r--r--pkg/abi/linux/socket.go13
-rw-r--r--pkg/atomicbitops/atomicbitops_amd64.s40
-rw-r--r--pkg/atomicbitops/atomicbitops_arm64.s16
-rw-r--r--pkg/atomicbitops/atomicbitops_noasm.go8
-rw-r--r--pkg/buffer/view_test.go18
-rw-r--r--pkg/crypto/crypto_stdlib.go15
-rw-r--r--pkg/eventfd/BUILD22
-rw-r--r--pkg/eventfd/eventfd.go115
-rw-r--r--pkg/eventfd/eventfd_test.go75
-rw-r--r--pkg/lisafs/BUILD117
-rw-r--r--pkg/lisafs/README.md3
-rw-r--r--pkg/lisafs/channel.go190
-rw-r--r--pkg/lisafs/client.go432
-rw-r--r--pkg/lisafs/client_file.go528
-rw-r--r--pkg/lisafs/communicator.go80
-rw-r--r--pkg/lisafs/connection.go320
-rw-r--r--pkg/lisafs/connection_test.go194
-rw-r--r--pkg/lisafs/fd.go374
-rw-r--r--pkg/lisafs/handlers.go768
-rw-r--r--pkg/lisafs/lisafs.go18
-rw-r--r--pkg/lisafs/message.go1251
-rw-r--r--pkg/lisafs/sample_message.go110
-rw-r--r--pkg/lisafs/server.go113
-rw-r--r--pkg/lisafs/sock.go208
-rw-r--r--pkg/lisafs/sock_test.go217
-rw-r--r--pkg/lisafs/testsuite/BUILD20
-rw-r--r--pkg/lisafs/testsuite/testsuite.go637
-rw-r--r--pkg/p9/client.go2
-rw-r--r--pkg/ring0/defs.go12
-rw-r--r--pkg/ring0/defs_amd64.go5
-rw-r--r--pkg/ring0/entry_amd64.go5
-rw-r--r--pkg/ring0/entry_amd64.s177
-rw-r--r--pkg/ring0/kernel.go5
-rw-r--r--pkg/ring0/kernel_amd64.go23
-rw-r--r--pkg/ring0/lib_amd64.go42
-rw-r--r--pkg/ring0/lib_amd64.s23
-rw-r--r--pkg/ring0/offsets_amd64.go3
-rw-r--r--pkg/ring0/pagetables/pagetables.go9
-rw-r--r--pkg/safecopy/BUILD2
-rw-r--r--pkg/safecopy/safecopy.go5
-rw-r--r--pkg/safecopy/safecopy_unsafe.go37
-rw-r--r--pkg/sentry/control/pprof.go30
-rw-r--r--pkg/sentry/fs/fdpipe/pipe.go3
-rw-r--r--pkg/sentry/fs/file_overlay.go12
-rw-r--r--pkg/sentry/fs/fsutil/file.go4
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go21
-rw-r--r--pkg/sentry/fs/host/inode.go2
-rw-r--r--pkg/sentry/fs/inotify.go2
-rw-r--r--pkg/sentry/fs/lock/lock.go2
-rw-r--r--pkg/sentry/fs/proc/sys.go3
-rw-r--r--pkg/sentry/fs/timerfd/timerfd.go2
-rw-r--r--pkg/sentry/fs/tty/line_discipline.go4
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/BUILD13
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/bitmap.go139
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/bitmap_test.go99
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cgroupfs.go2
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cpuset.go114
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD3
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go101
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go515
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go677
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer_test.go1
-rw-r--r--pkg/sentry/fsimpl/gofer/handle.go80
-rw-r--r--pkg/sentry/fsimpl/gofer/p9file.go8
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go26
-rw-r--r--pkg/sentry/fsimpl/gofer/revalidate.go50
-rw-r--r--pkg/sentry/fsimpl/gofer/save_restore.go151
-rw-r--r--pkg/sentry/fsimpl/gofer/socket.go45
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go18
-rw-r--r--pkg/sentry/fsimpl/gofer/symlink.go8
-rw-r--r--pkg/sentry/fsimpl/gofer/time.go5
-rw-r--r--pkg/sentry/fsimpl/mqfs/BUILD2
-rw-r--r--pkg/sentry/fsimpl/mqfs/mqfs.go4
-rw-r--r--pkg/sentry/fsimpl/mqfs/registry.go6
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go2
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go2
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go5
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go9
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go14
-rw-r--r--pkg/sentry/fsimpl/sys/sys_test.go14
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go2
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go14
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go2
-rw-r--r--pkg/sentry/hostmm/BUILD3
-rw-r--r--pkg/sentry/hostmm/hostmm.go42
-rw-r--r--pkg/sentry/inet/BUILD13
-rw-r--r--pkg/sentry/kernel/BUILD3
-rw-r--r--pkg/sentry/kernel/epoll/epoll.go15
-rw-r--r--pkg/sentry/kernel/epoll/epoll_state.go2
-rw-r--r--pkg/sentry/kernel/eventfd/eventfd.go2
-rw-r--r--pkg/sentry/kernel/ipc/BUILD2
-rw-r--r--pkg/sentry/kernel/kernel.go46
-rw-r--r--pkg/sentry/kernel/mq/mq.go4
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go2
-rw-r--r--pkg/sentry/kernel/task.go6
-rw-r--r--pkg/sentry/kernel/task_clone.go2
-rw-r--r--pkg/sentry/kernel/task_log.go12
-rw-r--r--pkg/sentry/kernel/task_net.go12
-rw-r--r--pkg/sentry/kernel/task_start.go2
-rw-r--r--pkg/sentry/kernel/threads.go6
-rw-r--r--pkg/sentry/mm/syscalls.go2
-rw-r--r--pkg/sentry/platform/kvm/BUILD24
-rw-r--r--pkg/sentry/platform/kvm/bluepill.go7
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.go8
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.s27
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go10
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.s34
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go32
-rw-r--r--pkg/sentry/platform/kvm/kvm.go6
-rw-r--r--pkg/sentry/platform/kvm/kvm_safecopy_test.go104
-rw-r--r--pkg/sentry/platform/kvm/machine.go136
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go36
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64_unsafe.go12
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go125
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go15
-rw-r--r--pkg/sentry/platform/kvm/machine_unsafe.go43
-rw-r--r--pkg/sentry/platform/kvm/physical_map.go3
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_arm64.go4
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_arm64.s35
-rw-r--r--pkg/sentry/seccheck/BUILD4
-rw-r--r--pkg/sentry/seccheck/execve.go65
-rw-r--r--pkg/sentry/seccheck/exit.go57
-rw-r--r--pkg/sentry/seccheck/seccheck.go26
-rw-r--r--pkg/sentry/socket/BUILD5
-rw-r--r--pkg/sentry/socket/control/control.go32
-rw-r--r--pkg/sentry/socket/control/control_test.go2
-rw-r--r--pkg/sentry/socket/hostinet/socket.go28
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go4
-rw-r--r--pkg/sentry/socket/netfilter/targets.go2
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go137
-rw-r--r--pkg/sentry/socket/netstack/netstack_state.go (renamed from pkg/tcpip/stack/iptables_state.go)29
-rw-r--r--pkg/sentry/socket/netstack/stack.go2
-rw-r--r--pkg/sentry/socket/socket.go37
-rw-r--r--pkg/sentry/socket/socket_state.go27
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go8
-rw-r--r--pkg/sentry/strace/strace.go4
-rw-r--r--pkg/sentry/time/sampler_arm64.go4
-rw-r--r--pkg/sentry/usage/memory.go6
-rw-r--r--pkg/sentry/vfs/epoll.go6
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go4
-rw-r--r--pkg/sentry/vfs/resolving_path.go6
-rw-r--r--pkg/sentry/vfs/save_restore.go29
-rw-r--r--pkg/shim/service.go20
-rw-r--r--pkg/shim/service_test.go20
-rw-r--r--pkg/sighandling/BUILD (renamed from pkg/sentry/sighandling/BUILD)2
-rw-r--r--pkg/sighandling/sighandling.go (renamed from pkg/sentry/sighandling/sighandling.go)0
-rw-r--r--pkg/sighandling/sighandling_unsafe.go (renamed from pkg/sentry/sighandling/sighandling_unsafe.go)34
-rw-r--r--pkg/sync/BUILD1
-rw-r--r--pkg/sync/atomicptr/generic_atomicptr_unsafe.go2
-rw-r--r--pkg/sync/wait.go58
-rw-r--r--pkg/tcpip/BUILD1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go24
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go112
-rw-r--r--pkg/tcpip/checker/checker.go13
-rw-r--r--pkg/tcpip/header/ipv4.go9
-rw-r--r--pkg/tcpip/header/parse/parse.go68
-rw-r--r--pkg/tcpip/link/pipe/pipe.go4
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go16
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD31
-rw-r--r--pkg/tcpip/link/sharedmem/queue/rx.go2
-rw-r--r--pkg/tcpip/link/sharedmem/queuepair.go199
-rw-r--r--pkg/tcpip/link/sharedmem/rx.go30
-rw-r--r--pkg/tcpip/link/sharedmem/server_rx.go142
-rw-r--r--pkg/tcpip/link/sharedmem/server_tx.go175
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go230
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_server.go344
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_server_test.go220
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go114
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_unsafe.go33
-rw-r--r--pkg/tcpip/link/sharedmem/tx.go32
-rw-r--r--pkg/tcpip/network/arp/arp.go1
-rw-r--r--pkg/tcpip/network/arp/arp_test.go16
-rw-r--r--pkg/tcpip/network/ip_test.go86
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go124
-rw-r--r--pkg/tcpip/network/ipv4/igmp_test.go28
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go114
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go218
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go166
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go71
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go137
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go243
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go33
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go6
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go48
-rw-r--r--pkg/tcpip/network/multicast_group_test.go21
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go8
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go16
-rw-r--r--pkg/tcpip/socketops.go18
-rw-r--r--pkg/tcpip/stack/BUILD2
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go24
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go4
-rw-r--r--pkg/tcpip/stack/conntrack.go685
-rw-r--r--pkg/tcpip/stack/conntrack_test.go132
-rw-r--r--pkg/tcpip/stack/forwarding_test.go26
-rw-r--r--pkg/tcpip/stack/icmp_rate_limit.go39
-rw-r--r--pkg/tcpip/stack/iptables.go216
-rw-r--r--pkg/tcpip/stack/iptables_targets.go201
-rw-r--r--pkg/tcpip/stack/iptables_types.go28
-rw-r--r--pkg/tcpip/stack/ndp_test.go107
-rw-r--r--pkg/tcpip/stack/nic.go43
-rw-r--r--pkg/tcpip/stack/nic_test.go5
-rw-r--r--pkg/tcpip/stack/packet_buffer.go69
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go38
-rw-r--r--pkg/tcpip/stack/registration.go29
-rw-r--r--pkg/tcpip/stack/stack.go57
-rw-r--r--pkg/tcpip/stack/stack_test.go642
-rw-r--r--pkg/tcpip/stack/tcp.go6
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go48
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go18
-rw-r--r--pkg/tcpip/stack/transport_test.go40
-rw-r--r--pkg/tcpip/tcpip.go54
-rw-r--r--pkg/tcpip/tcpip_state.go27
-rw-r--r--pkg/tcpip/tests/integration/BUILD26
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go20
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go695
-rw-r--r--pkg/tcpip/tests/integration/istio_test.go365
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go24
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go47
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go20
-rw-r--r--pkg/tcpip/tests/integration/route_test.go38
-rw-r--r--pkg/tcpip/tests/utils/utils.go68
-rw-r--r--pkg/tcpip/transport/icmp/BUILD2
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go438
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go35
-rw-r--r--pkg/tcpip/transport/icmp/icmp_test.go8
-rw-r--r--pkg/tcpip/transport/internal/network/BUILD1
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go177
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_state.go12
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint_test.go33
-rw-r--r--pkg/tcpip/transport/internal/noop/BUILD14
-rw-r--r--pkg/tcpip/transport/internal/noop/endpoint.go172
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go120
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go18
-rw-r--r--pkg/tcpip/transport/raw/BUILD1
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go44
-rw-r--r--pkg/tcpip/transport/raw/protocol.go16
-rw-r--r--pkg/tcpip/transport/tcp/BUILD13
-rw-r--r--pkg/tcpip/transport/tcp/accept.go346
-rw-r--r--pkg/tcpip/transport/tcp/connect.go9
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go115
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go8
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go12
-rw-r--r--pkg/tcpip/transport/tcp/rcv_test.go2
-rw-r--r--pkg/tcpip/transport/tcp/segment_test.go2
-rw-r--r--pkg/tcpip/transport/tcp/snd.go127
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go15
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go255
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go110
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go8
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go27
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go123
-rw-r--r--pkg/unet/BUILD1
-rw-r--r--pkg/unet/unet.go20
-rw-r--r--pkg/unet/unet_unsafe.go2
-rw-r--r--runsc/boot/BUILD3
-rw-r--r--runsc/boot/filter/config.go18
-rw-r--r--runsc/boot/fs.go14
-rw-r--r--runsc/boot/loader.go47
-rw-r--r--runsc/boot/loader_test.go11
-rw-r--r--runsc/boot/network.go13
-rw-r--r--runsc/boot/profile.go95
-rw-r--r--runsc/boot/vfs.go6
-rw-r--r--runsc/cgroup/BUILD1
-rw-r--r--runsc/cgroup/cgroup.go99
-rw-r--r--runsc/cgroup/cgroup_test.go36
-rw-r--r--runsc/cli/main.go3
-rw-r--r--runsc/cmd/boot.go50
-rw-r--r--runsc/cmd/debug.go95
-rw-r--r--runsc/cmd/gofer.go120
-rw-r--r--runsc/cmd/run.go9
-rw-r--r--runsc/cmd/spec.go1
-rw-r--r--runsc/config/config.go38
-rw-r--r--runsc/config/flags.go8
-rw-r--r--runsc/container/BUILD2
-rw-r--r--runsc/container/container.go141
-rw-r--r--runsc/container/container_test.go47
-rw-r--r--runsc/container/multi_container_test.go1
-rw-r--r--runsc/fsgofer/BUILD16
-rw-r--r--runsc/fsgofer/fsgofer_test.go50
-rw-r--r--runsc/fsgofer/fsgofer_unsafe.go30
-rw-r--r--runsc/fsgofer/lisafs.go1084
-rw-r--r--runsc/fsgofer/lisafs_test.go56
-rw-r--r--runsc/sandbox/BUILD10
-rw-r--r--runsc/sandbox/memory.go70
-rw-r--r--runsc/sandbox/memory_test.go91
-rw-r--r--runsc/sandbox/sandbox.go69
-rw-r--r--runsc/specutils/specutils.go4
-rw-r--r--test/benchmarks/tcp/tcp_proxy.go8
-rw-r--r--test/e2e/integration_test.go44
-rw-r--r--test/packetimpact/runner/dut.go26
-rw-r--r--test/packetimpact/tests/tcp_listen_backlog_test.go310
-rw-r--r--test/runner/main.go8
-rw-r--r--test/runtimes/runner/lib/BUILD6
-rw-r--r--test/runtimes/runner/lib/go_test_dependency_go118.go27
-rw-r--r--test/runtimes/runner/lib/go_test_dependency_not_go118.go25
-rw-r--r--test/runtimes/runner/lib/lib.go22
-rw-r--r--test/syscalls/BUILD9
-rw-r--r--test/syscalls/linux/BUILD53
-rw-r--r--test/syscalls/linux/cgroup.cc81
-rw-r--r--test/syscalls/linux/epoll.cc36
-rw-r--r--test/syscalls/linux/ip_socket_test_util.cc97
-rw-r--r--test/syscalls/linux/ip_socket_test_util.h19
-rw-r--r--test/syscalls/linux/mmap.cc41
-rw-r--r--test/syscalls/linux/mq.cc4
-rw-r--r--test/syscalls/linux/packet_socket.cc715
-rw-r--r--test/syscalls/linux/packet_socket_dgram.cc550
-rw-r--r--test/syscalls/linux/ping_socket.cc6
-rw-r--r--test/syscalls/linux/proc.cc49
-rw-r--r--test/syscalls/linux/proc_isolated.cc100
-rw-r--r--test/syscalls/linux/raw_socket.cc129
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc2
-rw-r--r--test/syscalls/linux/socket_ip_udp_unbound_external_networking.cc59
-rw-r--r--test/syscalls/linux/socket_ip_udp_unbound_external_networking.h18
-rw-r--r--test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc10
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc220
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h18
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound.cc80
-rw-r--r--test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc2
-rw-r--r--test/syscalls/linux/tcp_socket.cc125
-rw-r--r--test/util/BUILD12
-rw-r--r--test/util/capability_util.h2
-rw-r--r--test/util/cgroup_util.cc3
-rw-r--r--test/util/socket_util.h12
-rw-r--r--tools/bazel.mk4
-rw-r--r--tools/bigquery/bigquery.go131
-rw-r--r--tools/checklocks/analysis.go65
-rw-r--r--tools/checklocks/annotations.go17
-rw-r--r--tools/checklocks/facts.go71
-rw-r--r--tools/checklocks/state.go101
-rw-r--r--tools/checklocks/test/BUILD1
-rw-r--r--tools/checklocks/test/basics.go6
-rw-r--r--tools/checklocks/test/rwmutex.go52
-rw-r--r--tools/go_fieldenum/main.go14
-rw-r--r--tools/nogo/build.go5
-rw-r--r--tools/nogo/check/main.go8
-rw-r--r--tools/nogo/defs.bzl4
-rw-r--r--tools/nogo/nogo.go46
350 files changed, 21197 insertions, 4938 deletions
diff --git a/README.md b/README.md
index aa8b8c068..3d380350f 100644
--- a/README.md
+++ b/README.md
@@ -61,8 +61,9 @@ Make sure the following dependencies are installed:
Build and install the `runsc` binary:
```sh
-make runsc
-sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin
+mkdir -p bin
+make copy TARGETS=runsc DESTINATION=bin/
+sudo cp ./bin/runsc /usr/local/bin
```
### Testing
diff --git a/WORKSPACE b/WORKSPACE
index b00c09682..6dc647292 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -42,10 +42,10 @@ http_archive(
# binaries of symbols, which we don't want.
"//tools:rules_go_symbols.patch",
],
- sha256 = "69de5c704a05ff37862f7e0f5534d4f479418afc21806c887db544a316f3cb6b",
+ sha256 = "8e968b5fcea1d2d64071872b12737bbb5514524ee5f0a4f54f5920266c261acb",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.27.0/rules_go-v0.27.0.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/v0.27.0/rules_go-v0.27.0.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.28.0/rules_go-v0.28.0.zip",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.28.0/rules_go-v0.28.0.zip",
],
)
@@ -69,7 +69,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_depe
go_rules_dependencies()
-go_register_toolchains(go_version = "1.16.2")
+go_register_toolchains(go_version = "1.16.8")
load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository")
@@ -113,11 +113,11 @@ cc_crosstool(name = "crosstool")
# Load protobuf dependencies.
http_archive(
name = "rules_proto",
- sha256 = "2a20fd8af3cad3fbab9fd3aec4a137621e0c31f858af213a7ae0f997723fc4a9",
- strip_prefix = "rules_proto-a0761ed101b939e19d83b2da5f59034bffc19c12",
+ sha256 = "66bfdf8782796239d3875d37e7de19b1d94301e8972b3cbd2446b332429b4df1",
+ strip_prefix = "rules_proto-4.0.0",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/rules_proto/archive/a0761ed101b939e19d83b2da5f59034bffc19c12.tar.gz",
- "https://github.com/bazelbuild/rules_proto/archive/a0761ed101b939e19d83b2da5f59034bffc19c12.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_proto/archive/refs/tags/4.0.0.tar.gz",
+ "https://github.com/bazelbuild/rules_proto/archive/refs/tags/4.0.0.tar.gz",
],
)
@@ -1256,7 +1256,7 @@ http_archive(
strip_prefix = "abseil-cpp-278e0a071885a22dcd2fd1b5576cc44757299343",
urls = [
"https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/278e0a071885a22dcd2fd1b5576cc44757299343.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/278e0a071885a22dcd2fd1b5576cc44757299343.tar.gz"
+ "https://github.com/abseil/abseil-cpp/archive/278e0a071885a22dcd2fd1b5576cc44757299343.tar.gz",
],
)
@@ -1278,7 +1278,6 @@ load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps")
grpc_extra_deps()
-
http_archive(
name = "com_google_googletest",
sha256 = "0a10bea96d8670e5eef948d79d824162b1577bb7889539e49ec786bfc3e48912",
@@ -1305,7 +1304,9 @@ http_archive(
strip_prefix = "protobuf-3.17.3",
urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.17.3.zip"],
)
+
load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
+
protobuf_deps()
# Schemas for testing.
@@ -1635,7 +1636,7 @@ go_repository(
importpath = "honnef.co/go/tools",
sum = "h1:Tyybiul3hjaq0dkv+kcf5/MPTfo+ZBiEWrkhgxMPH54=",
version = "v0.3.0-0.dev.0.20210801021341-453cb28c0b15",
- )
+)
go_repository(
name = "com_github_burntsushi_toml",
diff --git a/g3doc/README.md b/g3doc/README.md
index 22bfb15f7..dc4179037 100644
--- a/g3doc/README.md
+++ b/g3doc/README.md
@@ -139,8 +139,8 @@ pipes, etc) are sent to the Gofer, described below.
The Gofer is a standard host process which is started with each container and
communicates with the Sentry via the [9P protocol][9p] over a socket or shared
memory channel. The Sentry process is started in a restricted seccomp container
-without access to file system resources. The Gofer mediates all access to the
-these resources, providing an additional level of isolation.
+without access to file system resources. The Gofer mediates all access to these
+resources, providing an additional level of isolation.
### Application {#application}
diff --git a/g3doc/user_guide/compatibility.md b/g3doc/user_guide/compatibility.md
index 76e879a01..ef50c0147 100644
--- a/g3doc/user_guide/compatibility.md
+++ b/g3doc/user_guide/compatibility.md
@@ -42,6 +42,9 @@ Most common utilities work. Note that:
* Some tools, such as `tcpdump` and old versions of `ping`, require explicitly
enabling raw sockets via the unsafe `--net-raw` runsc flag.
+ * In case of tcpdump the following invocations will work
+ * tcpdump -i any
+ * tcpdump -i \<device-name\> -p (-p disables promiscuous mode)
* Different Docker images can behave differently. For example, Alpine Linux
and Ubuntu have different `ip` binaries.
@@ -82,7 +85,7 @@ Most common utilities work. Note that:
| sshd | Partially working. Job control [in progress](https://gvisor.dev/issue/154). |
| strace | Working. |
| tar | Working. |
-| tcpdump | Working. [Promiscuous mode in progress](https://gvisor.dev/issue/3333). |
+| tcpdump | Working [only with libpcap versions < 1.10](https://github.com/google/gvisor/issues/6699), [Promiscuous mode in progress](https://gvisor.dev/issue/3333). |
| top | Working. |
| uptime | Working. |
| vim | Working. |
diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md
index 85ba6a161..73e953905 100644
--- a/g3doc/user_guide/install.md
+++ b/g3doc/user_guide/install.md
@@ -51,18 +51,17 @@ sudo apt-get install -y \
apt-transport-https \
ca-certificates \
curl \
- gnupg-agent \
- software-properties-common
+ gnupg
```
Next, configure the key used to sign archives and the repository.
NOTE: The key was updated on 2021-07-13 to replace the expired key. If you get
-errors about the key being expired, run the `apt-key add` command below again.
+errors about the key being expired, run the `curl` command below again.
```bash
-curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add -
-sudo add-apt-repository "deb [arch=amd64,arm64] https://storage.googleapis.com/gvisor/releases release main"
+curl -fsSL https://gvisor.dev/archive.key | sudo gpg --dearmor -o /usr/share/keyrings/gvisor-archive-keyring.gpg
+echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/gvisor-archive-keyring.gpg] https://storage.googleapis.com/gvisor/releases release main" | sudo tee /etc/apt/sources.list.d/gvisor.list > /dev/null
```
Now the runsc package can be installed:
diff --git a/images/benchmarks/node/package-lock.json b/images/benchmarks/node/package-lock.json
index 580e68aa5..9f59a3a71 100644
--- a/images/benchmarks/node/package-lock.json
+++ b/images/benchmarks/node/package-lock.json
@@ -1,63 +1,687 @@
{
"name": "nodedum",
"version": "1.0.0",
- "lockfileVersion": 1,
+ "lockfileVersion": 2,
"requires": true,
- "dependencies": {
- "accepts": {
- "version": "1.3.5",
- "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.5.tgz",
- "integrity": "sha1-63d99gEXI6OxTopywIBcjoZ0a9I=",
- "requires": {
- "mime-types": "~2.1.18",
- "negotiator": "0.6.1"
+ "packages": {
+ "": {
+ "name": "nodedum",
+ "version": "1.0.0",
+ "license": "ISC",
+ "dependencies": {
+ "express": "^4.16.4",
+ "hbs": "^4.0.4",
+ "redis": "^3.1.2",
+ "redis-commands": "^1.2.0",
+ "redis-parser": "^2.6.0",
+ "secure-random-string": "^1.1.0"
}
},
- "array-flatten": {
+ "node_modules/accepts": {
+ "version": "1.3.7",
+ "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.7.tgz",
+ "integrity": "sha512-Il80Qs2WjYlJIBNzNkK6KYqlVMTbZLXgHx2oT0pU/fjRHyEp+PEfEPY0R3WCwAGVOtauxh1hOxNgIf5bv7dQpA==",
+ "dependencies": {
+ "mime-types": "~2.1.24",
+ "negotiator": "0.6.2"
+ },
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/array-flatten": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz",
"integrity": "sha1-ml9pkFGx5wczKPKgCJaLZOopVdI="
},
- "async": {
+ "node_modules/body-parser": {
+ "version": "1.19.0",
+ "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.19.0.tgz",
+ "integrity": "sha512-dhEPs72UPbDnAQJ9ZKMNTP6ptJaionhP5cBb541nXPlW60Jepo9RV/a4fX4XWW9CuFNK22krhrj1+rgzifNCsw==",
+ "dependencies": {
+ "bytes": "3.1.0",
+ "content-type": "~1.0.4",
+ "debug": "2.6.9",
+ "depd": "~1.1.2",
+ "http-errors": "1.7.2",
+ "iconv-lite": "0.4.24",
+ "on-finished": "~2.3.0",
+ "qs": "6.7.0",
+ "raw-body": "2.4.0",
+ "type-is": "~1.6.17"
+ },
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/bytes": {
+ "version": "3.1.0",
+ "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.0.tgz",
+ "integrity": "sha512-zauLjrfCG+xvoyaqLoV8bLVXXNGC4JqlxFCutSDWA6fJrTo2ZuvLYTqZ7aHBLZSMOopbzwv8f+wZcVzfVTI2Dg==",
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/content-disposition": {
+ "version": "0.5.3",
+ "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.3.tgz",
+ "integrity": "sha512-ExO0774ikEObIAEV9kDo50o+79VCUdEB6n6lzKgGwupcVeRlhrj3qGAfwq8G6uBJjkqLrhT0qEYFcWng8z1z0g==",
+ "dependencies": {
+ "safe-buffer": "5.1.2"
+ },
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/content-type": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.4.tgz",
+ "integrity": "sha512-hIP3EEPs8tB9AT1L+NUqtwOAps4mk2Zob89MWXMHjHWg9milF/j4osnnQLXBCBFBk/tvIG/tUc9mOUJiPBhPXA==",
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/cookie": {
+ "version": "0.4.0",
+ "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.4.0.tgz",
+ "integrity": "sha512-+Hp8fLp57wnUSt0tY0tHEXh4voZRDnoIrZPqlo3DPiI4y9lwg/jqx+1Om94/W6ZaPDOUbnjOt/99w66zk+l1Xg==",
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/cookie-signature": {
+ "version": "1.0.6",
+ "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz",
+ "integrity": "sha1-4wOogrNCzD7oylE6eZmXNNqzriw="
+ },
+ "node_modules/debug": {
+ "version": "2.6.9",
+ "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz",
+ "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==",
+ "dependencies": {
+ "ms": "2.0.0"
+ }
+ },
+ "node_modules/denque": {
+ "version": "1.5.1",
+ "resolved": "https://registry.npmjs.org/denque/-/denque-1.5.1.tgz",
+ "integrity": "sha512-XwE+iZ4D6ZUB7mfYRMb5wByE8L74HCn30FBN7sWnXksWc1LO1bPDl67pBR9o/kC4z/xSNAwkMYcGgqDV3BE3Hw==",
+ "engines": {
+ "node": ">=0.10"
+ }
+ },
+ "node_modules/depd": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/depd/-/depd-1.1.2.tgz",
+ "integrity": "sha1-m81S4UwJd2PnSbJ0xDRu0uVgtak=",
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/destroy": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/destroy/-/destroy-1.0.4.tgz",
+ "integrity": "sha1-l4hXRCxEdJ5CBmE+N5RiBYJqvYA="
+ },
+ "node_modules/ee-first": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz",
+ "integrity": "sha1-WQxhFWsK4vTwJVcyoViyZrxWsh0="
+ },
+ "node_modules/encodeurl": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz",
+ "integrity": "sha1-rT/0yG7C0CkyL1oCw6mmBslbP1k=",
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/escape-html": {
+ "version": "1.0.3",
+ "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz",
+ "integrity": "sha1-Aljq5NPQwJdN4cFpGI7wBR0dGYg="
+ },
+ "node_modules/etag": {
+ "version": "1.8.1",
+ "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz",
+ "integrity": "sha1-Qa4u62XvpiJorr/qg6x9eSmbCIc=",
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/express": {
+ "version": "4.17.1",
+ "resolved": "https://registry.npmjs.org/express/-/express-4.17.1.tgz",
+ "integrity": "sha512-mHJ9O79RqluphRrcw2X/GTh3k9tVv8YcoyY4Kkh4WDMUYKRZUq0h1o0w2rrrxBqM7VoeUVqgb27xlEMXTnYt4g==",
+ "dependencies": {
+ "accepts": "~1.3.7",
+ "array-flatten": "1.1.1",
+ "body-parser": "1.19.0",
+ "content-disposition": "0.5.3",
+ "content-type": "~1.0.4",
+ "cookie": "0.4.0",
+ "cookie-signature": "1.0.6",
+ "debug": "2.6.9",
+ "depd": "~1.1.2",
+ "encodeurl": "~1.0.2",
+ "escape-html": "~1.0.3",
+ "etag": "~1.8.1",
+ "finalhandler": "~1.1.2",
+ "fresh": "0.5.2",
+ "merge-descriptors": "1.0.1",
+ "methods": "~1.1.2",
+ "on-finished": "~2.3.0",
+ "parseurl": "~1.3.3",
+ "path-to-regexp": "0.1.7",
+ "proxy-addr": "~2.0.5",
+ "qs": "6.7.0",
+ "range-parser": "~1.2.1",
+ "safe-buffer": "5.1.2",
+ "send": "0.17.1",
+ "serve-static": "1.14.1",
+ "setprototypeof": "1.1.1",
+ "statuses": "~1.5.0",
+ "type-is": "~1.6.18",
+ "utils-merge": "1.0.1",
+ "vary": "~1.1.2"
+ },
+ "engines": {
+ "node": ">= 0.10.0"
+ }
+ },
+ "node_modules/finalhandler": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-1.1.2.tgz",
+ "integrity": "sha512-aAWcW57uxVNrQZqFXjITpW3sIUQmHGG3qSb9mUah9MgMC4NeWhNOlNjXEYq3HjRAvL6arUviZGGJsBg6z0zsWA==",
+ "dependencies": {
+ "debug": "2.6.9",
+ "encodeurl": "~1.0.2",
+ "escape-html": "~1.0.3",
+ "on-finished": "~2.3.0",
+ "parseurl": "~1.3.3",
+ "statuses": "~1.5.0",
+ "unpipe": "~1.0.0"
+ },
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/foreachasync": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/foreachasync/-/foreachasync-3.0.0.tgz",
+ "integrity": "sha1-VQKYfchxS+M5IJfzLgBxyd7gfPY="
+ },
+ "node_modules/forwarded": {
+ "version": "0.2.0",
+ "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz",
+ "integrity": "sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow==",
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/fresh": {
+ "version": "0.5.2",
+ "resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz",
+ "integrity": "sha1-PYyt2Q2XZWn6g1qx+OSyOhBWBac=",
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/handlebars": {
+ "version": "4.7.7",
+ "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.7.tgz",
+ "integrity": "sha512-aAcXm5OAfE/8IXkcZvCepKU3VzW1/39Fb5ZuqMtgI/hT8X2YgoMvBY5dLhq/cpOvw7Lk1nK/UF71aLG/ZnVYRA==",
+ "dependencies": {
+ "minimist": "^1.2.5",
+ "neo-async": "^2.6.0",
+ "source-map": "^0.6.1",
+ "uglify-js": "^3.1.4",
+ "wordwrap": "^1.0.0"
+ },
+ "bin": {
+ "handlebars": "bin/handlebars"
+ },
+ "engines": {
+ "node": ">=0.4.7"
+ },
+ "optionalDependencies": {
+ "uglify-js": "^3.1.4"
+ }
+ },
+ "node_modules/hbs": {
+ "version": "4.1.2",
+ "resolved": "https://registry.npmjs.org/hbs/-/hbs-4.1.2.tgz",
+ "integrity": "sha512-WfBnQbozbdiTLjJu6P6Wturgvy0FN8xtRmIjmP0ebX9OGQrt+2S6UC7xX0IebHTCS1sXe20zfTzQ7yhjrEvrfQ==",
+ "dependencies": {
+ "handlebars": "4.7.7",
+ "walk": "2.3.14"
+ },
+ "engines": {
+ "node": ">= 0.8",
+ "npm": "1.2.8000 || >= 1.4.16"
+ }
+ },
+ "node_modules/http-errors": {
+ "version": "1.7.2",
+ "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-1.7.2.tgz",
+ "integrity": "sha512-uUQBt3H/cSIVfch6i1EuPNy/YsRSOUBXTVfZ+yR7Zjez3qjBz6i9+i4zjNaoqcoFVI4lQJ5plg63TvGfRSDCRg==",
+ "dependencies": {
+ "depd": "~1.1.2",
+ "inherits": "2.0.3",
+ "setprototypeof": "1.1.1",
+ "statuses": ">= 1.5.0 < 2",
+ "toidentifier": "1.0.0"
+ },
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/iconv-lite": {
+ "version": "0.4.24",
+ "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.24.tgz",
+ "integrity": "sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA==",
+ "dependencies": {
+ "safer-buffer": ">= 2.1.2 < 3"
+ },
+ "engines": {
+ "node": ">=0.10.0"
+ }
+ },
+ "node_modules/inherits": {
+ "version": "2.0.3",
+ "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz",
+ "integrity": "sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4="
+ },
+ "node_modules/ipaddr.js": {
+ "version": "1.9.1",
+ "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz",
+ "integrity": "sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==",
+ "engines": {
+ "node": ">= 0.10"
+ }
+ },
+ "node_modules/media-typer": {
+ "version": "0.3.0",
+ "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz",
+ "integrity": "sha1-hxDXrwqmJvj/+hzgAWhUUmMlV0g=",
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/merge-descriptors": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.1.tgz",
+ "integrity": "sha1-sAqqVW3YtEVoFQ7J0blT8/kMu2E="
+ },
+ "node_modules/methods": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/methods/-/methods-1.1.2.tgz",
+ "integrity": "sha1-VSmk1nZUE07cxSZmVoNbD4Ua/O4=",
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/mime": {
+ "version": "1.6.0",
+ "resolved": "https://registry.npmjs.org/mime/-/mime-1.6.0.tgz",
+ "integrity": "sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg==",
+ "bin": {
+ "mime": "cli.js"
+ },
+ "engines": {
+ "node": ">=4"
+ }
+ },
+ "node_modules/mime-db": {
+ "version": "1.49.0",
+ "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.49.0.tgz",
+ "integrity": "sha512-CIc8j9URtOVApSFCQIF+VBkX1RwXp/oMMOrqdyXSBXq5RWNEsRfyj1kiRnQgmNXmHxPoFIxOroKA3zcU9P+nAA==",
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/mime-types": {
+ "version": "2.1.32",
+ "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.32.tgz",
+ "integrity": "sha512-hJGaVS4G4c9TSMYh2n6SQAGrC4RnfU+daP8G7cSCmaqNjiOoUY0VHCMS42pxnQmVF1GWwFhbHWn3RIxCqTmZ9A==",
+ "dependencies": {
+ "mime-db": "1.49.0"
+ },
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/minimist": {
+ "version": "1.2.5",
+ "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.5.tgz",
+ "integrity": "sha512-FM9nNUYrRBAELZQT3xeZQ7fmMOBg6nWNmJKTcgsJeaLstP/UODVpGsr5OhXhhXg6f+qtJ8uiZ+PUxkDWcgIXLw=="
+ },
+ "node_modules/ms": {
+ "version": "2.0.0",
+ "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz",
+ "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g="
+ },
+ "node_modules/negotiator": {
+ "version": "0.6.2",
+ "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.2.tgz",
+ "integrity": "sha512-hZXc7K2e+PgeI1eDBe/10Ard4ekbfrrqG8Ep+8Jmf4JID2bNg7NvCPOZN+kfF574pFQI7mum2AUqDidoKqcTOw==",
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/neo-async": {
"version": "2.6.2",
- "resolved": "https://registry.npmjs.org/async/-/async-2.6.2.tgz",
- "integrity": "sha512-H1qVYh1MYhEEFLsP97cVKqCGo7KfCyTt6uEWqsTBr9SO84oK9Uwbyd/yCW+6rKJLHksBNUVWZDAjfS+Ccx0Bbg==",
+ "resolved": "https://registry.npmjs.org/neo-async/-/neo-async-2.6.2.tgz",
+ "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw=="
+ },
+ "node_modules/on-finished": {
+ "version": "2.3.0",
+ "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.3.0.tgz",
+ "integrity": "sha1-IPEzZIGwg811M3mSoWlxqi2QaUc=",
+ "dependencies": {
+ "ee-first": "1.1.1"
+ },
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/parseurl": {
+ "version": "1.3.3",
+ "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.3.tgz",
+ "integrity": "sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ==",
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/path-to-regexp": {
+ "version": "0.1.7",
+ "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.7.tgz",
+ "integrity": "sha1-32BBeABfUi8V60SQ5yR6G/qmf4w="
+ },
+ "node_modules/proxy-addr": {
+ "version": "2.0.7",
+ "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz",
+ "integrity": "sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==",
+ "dependencies": {
+ "forwarded": "0.2.0",
+ "ipaddr.js": "1.9.1"
+ },
+ "engines": {
+ "node": ">= 0.10"
+ }
+ },
+ "node_modules/qs": {
+ "version": "6.7.0",
+ "resolved": "https://registry.npmjs.org/qs/-/qs-6.7.0.tgz",
+ "integrity": "sha512-VCdBRNFTX1fyE7Nb6FYoURo/SPe62QCaAyzJvUjwRaIsc+NePBEniHlvxFmmX56+HZphIGtV0XeCirBtpDrTyQ==",
+ "engines": {
+ "node": ">=0.6"
+ }
+ },
+ "node_modules/range-parser": {
+ "version": "1.2.1",
+ "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.1.tgz",
+ "integrity": "sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg==",
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/raw-body": {
+ "version": "2.4.0",
+ "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.4.0.tgz",
+ "integrity": "sha512-4Oz8DUIwdvoa5qMJelxipzi/iJIi40O5cGV1wNYp5hvZP8ZN0T+jiNkL0QepXs+EsQ9XJ8ipEDoiH70ySUJP3Q==",
+ "dependencies": {
+ "bytes": "3.1.0",
+ "http-errors": "1.7.2",
+ "iconv-lite": "0.4.24",
+ "unpipe": "1.0.0"
+ },
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/redis": {
+ "version": "3.1.2",
+ "resolved": "https://registry.npmjs.org/redis/-/redis-3.1.2.tgz",
+ "integrity": "sha512-grn5KoZLr/qrRQVwoSkmzdbw6pwF+/rwODtrOr6vuBRiR/f3rjSTGupbF90Zpqm2oenix8Do6RV7pYEkGwlKkw==",
+ "dependencies": {
+ "denque": "^1.5.0",
+ "redis-commands": "^1.7.0",
+ "redis-errors": "^1.2.0",
+ "redis-parser": "^3.0.0"
+ },
+ "engines": {
+ "node": ">=10"
+ },
+ "funding": {
+ "type": "opencollective",
+ "url": "https://opencollective.com/node-redis"
+ }
+ },
+ "node_modules/redis-commands": {
+ "version": "1.7.0",
+ "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.7.0.tgz",
+ "integrity": "sha512-nJWqw3bTFy21hX/CPKHth6sfhZbdiHP6bTawSgQBlKOVRG7EZkfHbbHwQJnrE4vsQf0CMNE+3gJ4Fmm16vdVlQ=="
+ },
+ "node_modules/redis-errors": {
+ "version": "1.2.0",
+ "resolved": "https://registry.npmjs.org/redis-errors/-/redis-errors-1.2.0.tgz",
+ "integrity": "sha1-62LSrbFeTq9GEMBK/hUpOEJQq60=",
+ "engines": {
+ "node": ">=4"
+ }
+ },
+ "node_modules/redis-parser": {
+ "version": "2.6.0",
+ "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz",
+ "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs=",
+ "engines": {
+ "node": ">=0.10.0"
+ }
+ },
+ "node_modules/redis/node_modules/redis-parser": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-3.0.0.tgz",
+ "integrity": "sha1-tm2CjNyv5rS4pCin3vTGvKwxyLQ=",
+ "dependencies": {
+ "redis-errors": "^1.0.0"
+ },
+ "engines": {
+ "node": ">=4"
+ }
+ },
+ "node_modules/safe-buffer": {
+ "version": "5.1.2",
+ "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz",
+ "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g=="
+ },
+ "node_modules/safer-buffer": {
+ "version": "2.1.2",
+ "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz",
+ "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg=="
+ },
+ "node_modules/secure-random-string": {
+ "version": "1.1.3",
+ "resolved": "https://registry.npmjs.org/secure-random-string/-/secure-random-string-1.1.3.tgz",
+ "integrity": "sha512-298HxkJJp5mjpPhxDsN26S/2JmMaUIrQ4PxDI/F4fXKRBTOKendQ5i6JCkc+a8F8koLh0vdfwSCw8+RJkY7N6A=="
+ },
+ "node_modules/send": {
+ "version": "0.17.1",
+ "resolved": "https://registry.npmjs.org/send/-/send-0.17.1.tgz",
+ "integrity": "sha512-BsVKsiGcQMFwT8UxypobUKyv7irCNRHk1T0G680vk88yf6LBByGcZJOTJCrTP2xVN6yI+XjPJcNuE3V4fT9sAg==",
+ "dependencies": {
+ "debug": "2.6.9",
+ "depd": "~1.1.2",
+ "destroy": "~1.0.4",
+ "encodeurl": "~1.0.2",
+ "escape-html": "~1.0.3",
+ "etag": "~1.8.1",
+ "fresh": "0.5.2",
+ "http-errors": "~1.7.2",
+ "mime": "1.6.0",
+ "ms": "2.1.1",
+ "on-finished": "~2.3.0",
+ "range-parser": "~1.2.1",
+ "statuses": "~1.5.0"
+ },
+ "engines": {
+ "node": ">= 0.8.0"
+ }
+ },
+ "node_modules/send/node_modules/ms": {
+ "version": "2.1.1",
+ "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.1.tgz",
+ "integrity": "sha512-tgp+dl5cGk28utYktBsrFqA7HKgrhgPsg6Z/EfhWI4gl1Hwq8B/GmY/0oXZ6nF8hDVesS/FpnYaD/kOWhYQvyg=="
+ },
+ "node_modules/serve-static": {
+ "version": "1.14.1",
+ "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.14.1.tgz",
+ "integrity": "sha512-JMrvUwE54emCYWlTI+hGrGv5I8dEwmco/00EvkzIIsR7MqrHonbD9pO2MOfFnpFntl7ecpZs+3mW+XbQZu9QCg==",
+ "dependencies": {
+ "encodeurl": "~1.0.2",
+ "escape-html": "~1.0.3",
+ "parseurl": "~1.3.3",
+ "send": "0.17.1"
+ },
+ "engines": {
+ "node": ">= 0.8.0"
+ }
+ },
+ "node_modules/setprototypeof": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.1.1.tgz",
+ "integrity": "sha512-JvdAWfbXeIGaZ9cILp38HntZSFSo3mWg6xGcJJsd+d4aRMOqauag1C63dJfDw7OaMYwEbHMOxEZ1lqVRYP2OAw=="
+ },
+ "node_modules/source-map": {
+ "version": "0.6.1",
+ "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz",
+ "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==",
+ "engines": {
+ "node": ">=0.10.0"
+ }
+ },
+ "node_modules/statuses": {
+ "version": "1.5.0",
+ "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.5.0.tgz",
+ "integrity": "sha1-Fhx9rBd2Wf2YEfQ3cfqZOBR4Yow=",
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/toidentifier": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/toidentifier/-/toidentifier-1.0.0.tgz",
+ "integrity": "sha512-yaOH/Pk/VEhBWWTlhI+qXxDFXlejDGcQipMlyxda9nthulaxLZUNcUqFxokp0vcYnvteJln5FNQDRrxj3YcbVw==",
+ "engines": {
+ "node": ">=0.6"
+ }
+ },
+ "node_modules/type-is": {
+ "version": "1.6.18",
+ "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.18.tgz",
+ "integrity": "sha512-TkRKr9sUTxEH8MdfuCSP7VizJyzRNMjj2J2do2Jr3Kym598JVdEksuzPQCnlFPW4ky9Q+iA+ma9BGm06XQBy8g==",
+ "dependencies": {
+ "media-typer": "0.3.0",
+ "mime-types": "~2.1.24"
+ },
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/uglify-js": {
+ "version": "3.14.2",
+ "resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.14.2.tgz",
+ "integrity": "sha512-rtPMlmcO4agTUfz10CbgJ1k6UAoXM2gWb3GoMPPZB/+/Ackf8lNWk11K4rYi2D0apgoFRLtQOZhb+/iGNJq26A==",
+ "optional": true,
+ "bin": {
+ "uglifyjs": "bin/uglifyjs"
+ },
+ "engines": {
+ "node": ">=0.8.0"
+ }
+ },
+ "node_modules/unpipe": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz",
+ "integrity": "sha1-sr9O6FFKrmFltIF4KdIbLvSZBOw=",
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/utils-merge": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/utils-merge/-/utils-merge-1.0.1.tgz",
+ "integrity": "sha1-n5VxD1CiZ5R7LMwSR0HBAoQn5xM=",
+ "engines": {
+ "node": ">= 0.4.0"
+ }
+ },
+ "node_modules/vary": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz",
+ "integrity": "sha1-IpnwLG3tMNSllhsLn3RSShj2NPw=",
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/walk": {
+ "version": "2.3.14",
+ "resolved": "https://registry.npmjs.org/walk/-/walk-2.3.14.tgz",
+ "integrity": "sha512-5skcWAUmySj6hkBdH6B6+3ddMjVQYH5Qy9QGbPmN8kVmLteXk+yVXg+yfk1nbX30EYakahLrr8iPcCxJQSCBeg==",
+ "dependencies": {
+ "foreachasync": "^3.0.0"
+ }
+ },
+ "node_modules/wordwrap": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/wordwrap/-/wordwrap-1.0.0.tgz",
+ "integrity": "sha1-J1hIEIkUVqQXHI0CJkQa3pDLyus="
+ }
+ },
+ "dependencies": {
+ "accepts": {
+ "version": "1.3.7",
+ "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.7.tgz",
+ "integrity": "sha512-Il80Qs2WjYlJIBNzNkK6KYqlVMTbZLXgHx2oT0pU/fjRHyEp+PEfEPY0R3WCwAGVOtauxh1hOxNgIf5bv7dQpA==",
"requires": {
- "lodash": "^4.17.11"
+ "mime-types": "~2.1.24",
+ "negotiator": "0.6.2"
}
},
+ "array-flatten": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz",
+ "integrity": "sha1-ml9pkFGx5wczKPKgCJaLZOopVdI="
+ },
"body-parser": {
- "version": "1.18.3",
- "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.18.3.tgz",
- "integrity": "sha1-WykhmP/dVTs6DyDe0FkrlWlVyLQ=",
+ "version": "1.19.0",
+ "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.19.0.tgz",
+ "integrity": "sha512-dhEPs72UPbDnAQJ9ZKMNTP6ptJaionhP5cBb541nXPlW60Jepo9RV/a4fX4XWW9CuFNK22krhrj1+rgzifNCsw==",
"requires": {
- "bytes": "3.0.0",
+ "bytes": "3.1.0",
"content-type": "~1.0.4",
"debug": "2.6.9",
"depd": "~1.1.2",
- "http-errors": "~1.6.3",
- "iconv-lite": "0.4.23",
+ "http-errors": "1.7.2",
+ "iconv-lite": "0.4.24",
"on-finished": "~2.3.0",
- "qs": "6.5.2",
- "raw-body": "2.3.3",
- "type-is": "~1.6.16"
+ "qs": "6.7.0",
+ "raw-body": "2.4.0",
+ "type-is": "~1.6.17"
}
},
"bytes": {
- "version": "3.0.0",
- "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.0.0.tgz",
- "integrity": "sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg="
- },
- "commander": {
- "version": "2.20.0",
- "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.0.tgz",
- "integrity": "sha512-7j2y+40w61zy6YC2iRNpUe/NwhNyoXrYpHMrSunaMG64nRnaf96zO/KMQR4OyN/UnE5KLyEBnKHd4aG3rskjpQ==",
- "optional": true
+ "version": "3.1.0",
+ "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.0.tgz",
+ "integrity": "sha512-zauLjrfCG+xvoyaqLoV8bLVXXNGC4JqlxFCutSDWA6fJrTo2ZuvLYTqZ7aHBLZSMOopbzwv8f+wZcVzfVTI2Dg=="
},
"content-disposition": {
- "version": "0.5.2",
- "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.2.tgz",
- "integrity": "sha1-DPaLud318r55YcOoUXjLhdunjLQ="
+ "version": "0.5.3",
+ "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.3.tgz",
+ "integrity": "sha512-ExO0774ikEObIAEV9kDo50o+79VCUdEB6n6lzKgGwupcVeRlhrj3qGAfwq8G6uBJjkqLrhT0qEYFcWng8z1z0g==",
+ "requires": {
+ "safe-buffer": "5.1.2"
+ }
},
"content-type": {
"version": "1.0.4",
@@ -65,9 +689,9 @@
"integrity": "sha512-hIP3EEPs8tB9AT1L+NUqtwOAps4mk2Zob89MWXMHjHWg9milF/j4osnnQLXBCBFBk/tvIG/tUc9mOUJiPBhPXA=="
},
"cookie": {
- "version": "0.3.1",
- "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.3.1.tgz",
- "integrity": "sha1-5+Ch+e9DtMi6klxcWpboBtFoc7s="
+ "version": "0.4.0",
+ "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.4.0.tgz",
+ "integrity": "sha512-+Hp8fLp57wnUSt0tY0tHEXh4voZRDnoIrZPqlo3DPiI4y9lwg/jqx+1Om94/W6ZaPDOUbnjOt/99w66zk+l1Xg=="
},
"cookie-signature": {
"version": "1.0.6",
@@ -82,6 +706,11 @@
"ms": "2.0.0"
}
},
+ "denque": {
+ "version": "1.5.1",
+ "resolved": "https://registry.npmjs.org/denque/-/denque-1.5.1.tgz",
+ "integrity": "sha512-XwE+iZ4D6ZUB7mfYRMb5wByE8L74HCn30FBN7sWnXksWc1LO1bPDl67pBR9o/kC4z/xSNAwkMYcGgqDV3BE3Hw=="
+ },
"depd": {
"version": "1.1.2",
"resolved": "https://registry.npmjs.org/depd/-/depd-1.1.2.tgz",
@@ -92,11 +721,6 @@
"resolved": "https://registry.npmjs.org/destroy/-/destroy-1.0.4.tgz",
"integrity": "sha1-l4hXRCxEdJ5CBmE+N5RiBYJqvYA="
},
- "double-ended-queue": {
- "version": "2.1.0-0",
- "resolved": "https://registry.npmjs.org/double-ended-queue/-/double-ended-queue-2.1.0-0.tgz",
- "integrity": "sha1-ED01J/0xUo9AGIEwyEHv3XgmTlw="
- },
"ee-first": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz",
@@ -118,53 +742,53 @@
"integrity": "sha1-Qa4u62XvpiJorr/qg6x9eSmbCIc="
},
"express": {
- "version": "4.16.4",
- "resolved": "https://registry.npmjs.org/express/-/express-4.16.4.tgz",
- "integrity": "sha512-j12Uuyb4FMrd/qQAm6uCHAkPtO8FDTRJZBDd5D2KOL2eLaz1yUNdUB/NOIyq0iU4q4cFarsUCrnFDPBcnksuOg==",
+ "version": "4.17.1",
+ "resolved": "https://registry.npmjs.org/express/-/express-4.17.1.tgz",
+ "integrity": "sha512-mHJ9O79RqluphRrcw2X/GTh3k9tVv8YcoyY4Kkh4WDMUYKRZUq0h1o0w2rrrxBqM7VoeUVqgb27xlEMXTnYt4g==",
"requires": {
- "accepts": "~1.3.5",
+ "accepts": "~1.3.7",
"array-flatten": "1.1.1",
- "body-parser": "1.18.3",
- "content-disposition": "0.5.2",
+ "body-parser": "1.19.0",
+ "content-disposition": "0.5.3",
"content-type": "~1.0.4",
- "cookie": "0.3.1",
+ "cookie": "0.4.0",
"cookie-signature": "1.0.6",
"debug": "2.6.9",
"depd": "~1.1.2",
"encodeurl": "~1.0.2",
"escape-html": "~1.0.3",
"etag": "~1.8.1",
- "finalhandler": "1.1.1",
+ "finalhandler": "~1.1.2",
"fresh": "0.5.2",
"merge-descriptors": "1.0.1",
"methods": "~1.1.2",
"on-finished": "~2.3.0",
- "parseurl": "~1.3.2",
+ "parseurl": "~1.3.3",
"path-to-regexp": "0.1.7",
- "proxy-addr": "~2.0.4",
- "qs": "6.5.2",
- "range-parser": "~1.2.0",
+ "proxy-addr": "~2.0.5",
+ "qs": "6.7.0",
+ "range-parser": "~1.2.1",
"safe-buffer": "5.1.2",
- "send": "0.16.2",
- "serve-static": "1.13.2",
- "setprototypeof": "1.1.0",
- "statuses": "~1.4.0",
- "type-is": "~1.6.16",
+ "send": "0.17.1",
+ "serve-static": "1.14.1",
+ "setprototypeof": "1.1.1",
+ "statuses": "~1.5.0",
+ "type-is": "~1.6.18",
"utils-merge": "1.0.1",
"vary": "~1.1.2"
}
},
"finalhandler": {
- "version": "1.1.1",
- "resolved": "http://registry.npmjs.org/finalhandler/-/finalhandler-1.1.1.tgz",
- "integrity": "sha512-Y1GUDo39ez4aHAw7MysnUD5JzYX+WaIj8I57kO3aEPT1fFRL4sr7mjei97FgnwhAyyzRYmQZaTHb2+9uZ1dPtg==",
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-1.1.2.tgz",
+ "integrity": "sha512-aAWcW57uxVNrQZqFXjITpW3sIUQmHGG3qSb9mUah9MgMC4NeWhNOlNjXEYq3HjRAvL6arUviZGGJsBg6z0zsWA==",
"requires": {
"debug": "2.6.9",
"encodeurl": "~1.0.2",
"escape-html": "~1.0.3",
"on-finished": "~2.3.0",
- "parseurl": "~1.3.2",
- "statuses": "~1.4.0",
+ "parseurl": "~1.3.3",
+ "statuses": "~1.5.0",
"unpipe": "~1.0.0"
}
},
@@ -174,9 +798,9 @@
"integrity": "sha1-VQKYfchxS+M5IJfzLgBxyd7gfPY="
},
"forwarded": {
- "version": "0.1.2",
- "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.1.2.tgz",
- "integrity": "sha1-mMI9qxF1ZXuMBXPozszZGw/xjIQ="
+ "version": "0.2.0",
+ "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz",
+ "integrity": "sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow=="
},
"fresh": {
"version": "0.5.2",
@@ -184,40 +808,42 @@
"integrity": "sha1-PYyt2Q2XZWn6g1qx+OSyOhBWBac="
},
"handlebars": {
- "version": "4.0.14",
- "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.0.14.tgz",
- "integrity": "sha512-E7tDoyAA8ilZIV3xDJgl18sX3M8xB9/fMw8+mfW4msLW8jlX97bAnWgT3pmaNXuvzIEgSBMnAHfuXsB2hdzfow==",
+ "version": "4.7.7",
+ "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.7.tgz",
+ "integrity": "sha512-aAcXm5OAfE/8IXkcZvCepKU3VzW1/39Fb5ZuqMtgI/hT8X2YgoMvBY5dLhq/cpOvw7Lk1nK/UF71aLG/ZnVYRA==",
"requires": {
- "async": "^2.5.0",
- "optimist": "^0.6.1",
+ "minimist": "^1.2.5",
+ "neo-async": "^2.6.0",
"source-map": "^0.6.1",
- "uglify-js": "^3.1.4"
+ "uglify-js": "^3.1.4",
+ "wordwrap": "^1.0.0"
}
},
"hbs": {
- "version": "4.0.4",
- "resolved": "https://registry.npmjs.org/hbs/-/hbs-4.0.4.tgz",
- "integrity": "sha512-esVlyV/V59mKkwFai5YmPRSNIWZzhqL5YMN0++ueMxyK1cCfPa5f6JiHtapPKAIVAhQR6rpGxow0troav9WMEg==",
+ "version": "4.1.2",
+ "resolved": "https://registry.npmjs.org/hbs/-/hbs-4.1.2.tgz",
+ "integrity": "sha512-WfBnQbozbdiTLjJu6P6Wturgvy0FN8xtRmIjmP0ebX9OGQrt+2S6UC7xX0IebHTCS1sXe20zfTzQ7yhjrEvrfQ==",
"requires": {
- "handlebars": "4.0.14",
- "walk": "2.3.9"
+ "handlebars": "4.7.7",
+ "walk": "2.3.14"
}
},
"http-errors": {
- "version": "1.6.3",
- "resolved": "http://registry.npmjs.org/http-errors/-/http-errors-1.6.3.tgz",
- "integrity": "sha1-i1VoC7S+KDoLW/TqLjhYC+HZMg0=",
+ "version": "1.7.2",
+ "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-1.7.2.tgz",
+ "integrity": "sha512-uUQBt3H/cSIVfch6i1EuPNy/YsRSOUBXTVfZ+yR7Zjez3qjBz6i9+i4zjNaoqcoFVI4lQJ5plg63TvGfRSDCRg==",
"requires": {
"depd": "~1.1.2",
"inherits": "2.0.3",
- "setprototypeof": "1.1.0",
- "statuses": ">= 1.4.0 < 2"
+ "setprototypeof": "1.1.1",
+ "statuses": ">= 1.5.0 < 2",
+ "toidentifier": "1.0.0"
}
},
"iconv-lite": {
- "version": "0.4.23",
- "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.23.tgz",
- "integrity": "sha512-neyTUVFtahjf0mB3dZT77u+8O0QB89jFdnBkd5P1JgYPbPaia3gXXOVL2fq8VyU2gMMD7SaN7QukTB/pmXYvDA==",
+ "version": "0.4.24",
+ "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.24.tgz",
+ "integrity": "sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA==",
"requires": {
"safer-buffer": ">= 2.1.2 < 3"
}
@@ -228,18 +854,13 @@
"integrity": "sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4="
},
"ipaddr.js": {
- "version": "1.8.0",
- "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.8.0.tgz",
- "integrity": "sha1-6qM9bd16zo9/b+DJygRA5wZzix4="
- },
- "lodash": {
- "version": "4.17.15",
- "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.15.tgz",
- "integrity": "sha512-8xOcRHvCjnocdS5cpwXQXVzmmh5e5+saE2QGoeQmbKmRS6J3VQppPOIt0MnmE+4xlZoumy0GPG0D0MVIQbNA1A=="
+ "version": "1.9.1",
+ "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz",
+ "integrity": "sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g=="
},
"media-typer": {
"version": "0.3.0",
- "resolved": "http://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz",
+ "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz",
"integrity": "sha1-hxDXrwqmJvj/+hzgAWhUUmMlV0g="
},
"merge-descriptors": {
@@ -253,27 +874,27 @@
"integrity": "sha1-VSmk1nZUE07cxSZmVoNbD4Ua/O4="
},
"mime": {
- "version": "1.4.1",
- "resolved": "https://registry.npmjs.org/mime/-/mime-1.4.1.tgz",
- "integrity": "sha512-KI1+qOZu5DcW6wayYHSzR/tXKCDC5Om4s1z2QJjDULzLcmf3DvzS7oluY4HCTrc+9FiKmWUgeNLg7W3uIQvxtQ=="
+ "version": "1.6.0",
+ "resolved": "https://registry.npmjs.org/mime/-/mime-1.6.0.tgz",
+ "integrity": "sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg=="
},
"mime-db": {
- "version": "1.37.0",
- "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.37.0.tgz",
- "integrity": "sha512-R3C4db6bgQhlIhPU48fUtdVmKnflq+hRdad7IyKhtFj06VPNVdk2RhiYL3UjQIlso8L+YxAtFkobT0VK+S/ybg=="
+ "version": "1.49.0",
+ "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.49.0.tgz",
+ "integrity": "sha512-CIc8j9URtOVApSFCQIF+VBkX1RwXp/oMMOrqdyXSBXq5RWNEsRfyj1kiRnQgmNXmHxPoFIxOroKA3zcU9P+nAA=="
},
"mime-types": {
- "version": "2.1.21",
- "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.21.tgz",
- "integrity": "sha512-3iL6DbwpyLzjR3xHSFNFeb9Nz/M8WDkX33t1GFQnFOllWk8pOrh/LSrB5OXlnlW5P9LH73X6loW/eogc+F5lJg==",
+ "version": "2.1.32",
+ "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.32.tgz",
+ "integrity": "sha512-hJGaVS4G4c9TSMYh2n6SQAGrC4RnfU+daP8G7cSCmaqNjiOoUY0VHCMS42pxnQmVF1GWwFhbHWn3RIxCqTmZ9A==",
"requires": {
- "mime-db": "~1.37.0"
+ "mime-db": "1.49.0"
}
},
"minimist": {
- "version": "0.0.10",
- "resolved": "https://registry.npmjs.org/minimist/-/minimist-0.0.10.tgz",
- "integrity": "sha1-3j+YVD2/lggr5IrRoMfNqDYwHc8="
+ "version": "1.2.5",
+ "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.5.tgz",
+ "integrity": "sha512-FM9nNUYrRBAELZQT3xeZQ7fmMOBg6nWNmJKTcgsJeaLstP/UODVpGsr5OhXhhXg6f+qtJ8uiZ+PUxkDWcgIXLw=="
},
"ms": {
"version": "2.0.0",
@@ -281,9 +902,14 @@
"integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g="
},
"negotiator": {
- "version": "0.6.1",
- "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.1.tgz",
- "integrity": "sha1-KzJxhOiZIQEXeyhWP7XnECrNDKk="
+ "version": "0.6.2",
+ "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.2.tgz",
+ "integrity": "sha512-hZXc7K2e+PgeI1eDBe/10Ard4ekbfrrqG8Ep+8Jmf4JID2bNg7NvCPOZN+kfF574pFQI7mum2AUqDidoKqcTOw=="
+ },
+ "neo-async": {
+ "version": "2.6.2",
+ "resolved": "https://registry.npmjs.org/neo-async/-/neo-async-2.6.2.tgz",
+ "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw=="
},
"on-finished": {
"version": "2.3.0",
@@ -293,19 +919,10 @@
"ee-first": "1.1.1"
}
},
- "optimist": {
- "version": "0.6.1",
- "resolved": "https://registry.npmjs.org/optimist/-/optimist-0.6.1.tgz",
- "integrity": "sha1-2j6nRob6IaGaERwybpDrFaAZZoY=",
- "requires": {
- "minimist": "~0.0.1",
- "wordwrap": "~0.0.2"
- }
- },
"parseurl": {
- "version": "1.3.2",
- "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.2.tgz",
- "integrity": "sha1-/CidTtiZMRlGDBViUyYs3I3mW/M="
+ "version": "1.3.3",
+ "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.3.tgz",
+ "integrity": "sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ=="
},
"path-to-regexp": {
"version": "0.1.7",
@@ -313,61 +930,65 @@
"integrity": "sha1-32BBeABfUi8V60SQ5yR6G/qmf4w="
},
"proxy-addr": {
- "version": "2.0.4",
- "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.4.tgz",
- "integrity": "sha512-5erio2h9jp5CHGwcybmxmVqHmnCBZeewlfJ0pex+UW7Qny7OOZXTtH56TGNyBizkgiOwhJtMKrVzDTeKcySZwA==",
+ "version": "2.0.7",
+ "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz",
+ "integrity": "sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==",
"requires": {
- "forwarded": "~0.1.2",
- "ipaddr.js": "1.8.0"
+ "forwarded": "0.2.0",
+ "ipaddr.js": "1.9.1"
}
},
"qs": {
- "version": "6.5.2",
- "resolved": "https://registry.npmjs.org/qs/-/qs-6.5.2.tgz",
- "integrity": "sha512-N5ZAX4/LxJmF+7wN74pUD6qAh9/wnvdQcjq9TZjevvXzSUo7bfmw91saqMjzGS2xq91/odN2dW/WOl7qQHNDGA=="
+ "version": "6.7.0",
+ "resolved": "https://registry.npmjs.org/qs/-/qs-6.7.0.tgz",
+ "integrity": "sha512-VCdBRNFTX1fyE7Nb6FYoURo/SPe62QCaAyzJvUjwRaIsc+NePBEniHlvxFmmX56+HZphIGtV0XeCirBtpDrTyQ=="
},
"range-parser": {
- "version": "1.2.0",
- "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.0.tgz",
- "integrity": "sha1-9JvmtIeJTdxA3MlKMi9hEJLgDV4="
+ "version": "1.2.1",
+ "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.1.tgz",
+ "integrity": "sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg=="
},
"raw-body": {
- "version": "2.3.3",
- "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.3.3.tgz",
- "integrity": "sha512-9esiElv1BrZoI3rCDuOuKCBRbuApGGaDPQfjSflGxdy4oyzqghxu6klEkkVIvBje+FF0BX9coEv8KqW6X/7njw==",
+ "version": "2.4.0",
+ "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.4.0.tgz",
+ "integrity": "sha512-4Oz8DUIwdvoa5qMJelxipzi/iJIi40O5cGV1wNYp5hvZP8ZN0T+jiNkL0QepXs+EsQ9XJ8ipEDoiH70ySUJP3Q==",
"requires": {
- "bytes": "3.0.0",
- "http-errors": "1.6.3",
- "iconv-lite": "0.4.23",
+ "bytes": "3.1.0",
+ "http-errors": "1.7.2",
+ "iconv-lite": "0.4.24",
"unpipe": "1.0.0"
}
},
"redis": {
- "version": "2.8.0",
- "resolved": "https://registry.npmjs.org/redis/-/redis-2.8.0.tgz",
- "integrity": "sha512-M1OkonEQwtRmZv4tEWF2VgpG0JWJ8Fv1PhlgT5+B+uNq2cA3Rt1Yt/ryoR+vQNOQcIEgdCdfH0jr3bDpihAw1A==",
+ "version": "3.1.2",
+ "resolved": "https://registry.npmjs.org/redis/-/redis-3.1.2.tgz",
+ "integrity": "sha512-grn5KoZLr/qrRQVwoSkmzdbw6pwF+/rwODtrOr6vuBRiR/f3rjSTGupbF90Zpqm2oenix8Do6RV7pYEkGwlKkw==",
"requires": {
- "double-ended-queue": "^2.1.0-0",
- "redis-commands": "^1.2.0",
- "redis-parser": "^2.6.0"
+ "denque": "^1.5.0",
+ "redis-commands": "^1.7.0",
+ "redis-errors": "^1.2.0",
+ "redis-parser": "^3.0.0"
},
"dependencies": {
- "redis-commands": {
- "version": "1.4.0",
- "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.4.0.tgz",
- "integrity": "sha512-cu8EF+MtkwI4DLIT0x9P8qNTLFhQD4jLfxLR0cCNkeGzs87FN6879JOJwNQR/1zD7aSYNbU0hgsV9zGY71Itvw=="
- },
"redis-parser": {
- "version": "2.6.0",
- "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz",
- "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs="
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-3.0.0.tgz",
+ "integrity": "sha1-tm2CjNyv5rS4pCin3vTGvKwxyLQ=",
+ "requires": {
+ "redis-errors": "^1.0.0"
+ }
}
}
},
"redis-commands": {
- "version": "1.5.0",
- "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.5.0.tgz",
- "integrity": "sha512-6KxamqpZ468MeQC3bkWmCB1fp56XL64D4Kf0zJSwDZbVLLm7KFkoIcHrgRvQ+sk8dnhySs7+yBg94yIkAK7aJg=="
+ "version": "1.7.0",
+ "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.7.0.tgz",
+ "integrity": "sha512-nJWqw3bTFy21hX/CPKHth6sfhZbdiHP6bTawSgQBlKOVRG7EZkfHbbHwQJnrE4vsQf0CMNE+3gJ4Fmm16vdVlQ=="
+ },
+ "redis-errors": {
+ "version": "1.2.0",
+ "resolved": "https://registry.npmjs.org/redis-errors/-/redis-errors-1.2.0.tgz",
+ "integrity": "sha1-62LSrbFeTq9GEMBK/hUpOEJQq60="
},
"redis-parser": {
"version": "2.6.0",
@@ -385,14 +1006,14 @@
"integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg=="
},
"secure-random-string": {
- "version": "1.1.0",
- "resolved": "https://registry.npmjs.org/secure-random-string/-/secure-random-string-1.1.0.tgz",
- "integrity": "sha512-V/h8jqoz58zklNGybVhP++cWrxEPXlLM/6BeJ4e0a8zlb4BsbYRzFs16snrxByPa5LUxCVTD3M6EYIVIHR1fAg=="
+ "version": "1.1.3",
+ "resolved": "https://registry.npmjs.org/secure-random-string/-/secure-random-string-1.1.3.tgz",
+ "integrity": "sha512-298HxkJJp5mjpPhxDsN26S/2JmMaUIrQ4PxDI/F4fXKRBTOKendQ5i6JCkc+a8F8koLh0vdfwSCw8+RJkY7N6A=="
},
"send": {
- "version": "0.16.2",
- "resolved": "https://registry.npmjs.org/send/-/send-0.16.2.tgz",
- "integrity": "sha512-E64YFPUssFHEFBvpbbjr44NCLtI1AohxQ8ZSiJjQLskAdKuriYEP6VyGEsRDH8ScozGpkaX1BGvhanqCwkcEZw==",
+ "version": "0.17.1",
+ "resolved": "https://registry.npmjs.org/send/-/send-0.17.1.tgz",
+ "integrity": "sha512-BsVKsiGcQMFwT8UxypobUKyv7irCNRHk1T0G680vk88yf6LBByGcZJOTJCrTP2xVN6yI+XjPJcNuE3V4fT9sAg==",
"requires": {
"debug": "2.6.9",
"depd": "~1.1.2",
@@ -401,29 +1022,36 @@
"escape-html": "~1.0.3",
"etag": "~1.8.1",
"fresh": "0.5.2",
- "http-errors": "~1.6.2",
- "mime": "1.4.1",
- "ms": "2.0.0",
+ "http-errors": "~1.7.2",
+ "mime": "1.6.0",
+ "ms": "2.1.1",
"on-finished": "~2.3.0",
- "range-parser": "~1.2.0",
- "statuses": "~1.4.0"
+ "range-parser": "~1.2.1",
+ "statuses": "~1.5.0"
+ },
+ "dependencies": {
+ "ms": {
+ "version": "2.1.1",
+ "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.1.tgz",
+ "integrity": "sha512-tgp+dl5cGk28utYktBsrFqA7HKgrhgPsg6Z/EfhWI4gl1Hwq8B/GmY/0oXZ6nF8hDVesS/FpnYaD/kOWhYQvyg=="
+ }
}
},
"serve-static": {
- "version": "1.13.2",
- "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.13.2.tgz",
- "integrity": "sha512-p/tdJrO4U387R9oMjb1oj7qSMaMfmOyd4j9hOFoxZe2baQszgHcSWjuya/CiT5kgZZKRudHNOA0pYXOl8rQ5nw==",
+ "version": "1.14.1",
+ "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.14.1.tgz",
+ "integrity": "sha512-JMrvUwE54emCYWlTI+hGrGv5I8dEwmco/00EvkzIIsR7MqrHonbD9pO2MOfFnpFntl7ecpZs+3mW+XbQZu9QCg==",
"requires": {
"encodeurl": "~1.0.2",
"escape-html": "~1.0.3",
- "parseurl": "~1.3.2",
- "send": "0.16.2"
+ "parseurl": "~1.3.3",
+ "send": "0.17.1"
}
},
"setprototypeof": {
- "version": "1.1.0",
- "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.1.0.tgz",
- "integrity": "sha512-BvE/TwpZX4FXExxOxZyRGQQv651MSwmWKZGqvmPcRIjDqWub67kTKuIMx43cZZrS/cBBzwBcNDWoFxt2XEFIpQ=="
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.1.1.tgz",
+ "integrity": "sha512-JvdAWfbXeIGaZ9cILp38HntZSFSo3mWg6xGcJJsd+d4aRMOqauag1C63dJfDw7OaMYwEbHMOxEZ1lqVRYP2OAw=="
},
"source-map": {
"version": "0.6.1",
@@ -431,28 +1059,29 @@
"integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g=="
},
"statuses": {
- "version": "1.4.0",
- "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.4.0.tgz",
- "integrity": "sha512-zhSCtt8v2NDrRlPQpCNtw/heZLtfUDqxBM1udqikb/Hbk52LK4nQSwr10u77iopCW5LsyHpuXS0GnEc48mLeew=="
+ "version": "1.5.0",
+ "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.5.0.tgz",
+ "integrity": "sha1-Fhx9rBd2Wf2YEfQ3cfqZOBR4Yow="
+ },
+ "toidentifier": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/toidentifier/-/toidentifier-1.0.0.tgz",
+ "integrity": "sha512-yaOH/Pk/VEhBWWTlhI+qXxDFXlejDGcQipMlyxda9nthulaxLZUNcUqFxokp0vcYnvteJln5FNQDRrxj3YcbVw=="
},
"type-is": {
- "version": "1.6.16",
- "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.16.tgz",
- "integrity": "sha512-HRkVv/5qY2G6I8iab9cI7v1bOIdhm94dVjQCPFElW9W+3GeDOSHmy2EBYe4VTApuzolPcmgFTN3ftVJRKR2J9Q==",
+ "version": "1.6.18",
+ "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.18.tgz",
+ "integrity": "sha512-TkRKr9sUTxEH8MdfuCSP7VizJyzRNMjj2J2do2Jr3Kym598JVdEksuzPQCnlFPW4ky9Q+iA+ma9BGm06XQBy8g==",
"requires": {
"media-typer": "0.3.0",
- "mime-types": "~2.1.18"
+ "mime-types": "~2.1.24"
}
},
"uglify-js": {
- "version": "3.5.9",
- "resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.5.9.tgz",
- "integrity": "sha512-WpT0RqsDtAWPNJK955DEnb6xjymR8Fn0OlK4TT4pS0ASYsVPqr5ELhgwOwLCP5J5vHeJ4xmMmz3DEgdqC10JeQ==",
- "optional": true,
- "requires": {
- "commander": "~2.20.0",
- "source-map": "~0.6.1"
- }
+ "version": "3.14.2",
+ "resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.14.2.tgz",
+ "integrity": "sha512-rtPMlmcO4agTUfz10CbgJ1k6UAoXM2gWb3GoMPPZB/+/Ackf8lNWk11K4rYi2D0apgoFRLtQOZhb+/iGNJq26A==",
+ "optional": true
},
"unpipe": {
"version": "1.0.0",
@@ -470,17 +1099,17 @@
"integrity": "sha1-IpnwLG3tMNSllhsLn3RSShj2NPw="
},
"walk": {
- "version": "2.3.9",
- "resolved": "https://registry.npmjs.org/walk/-/walk-2.3.9.tgz",
- "integrity": "sha1-MbTbZnjyrgHDnqn7hyWpAx5Vins=",
+ "version": "2.3.14",
+ "resolved": "https://registry.npmjs.org/walk/-/walk-2.3.14.tgz",
+ "integrity": "sha512-5skcWAUmySj6hkBdH6B6+3ddMjVQYH5Qy9QGbPmN8kVmLteXk+yVXg+yfk1nbX30EYakahLrr8iPcCxJQSCBeg==",
"requires": {
"foreachasync": "^3.0.0"
}
},
"wordwrap": {
- "version": "0.0.3",
- "resolved": "https://registry.npmjs.org/wordwrap/-/wordwrap-0.0.3.tgz",
- "integrity": "sha1-o9XabNXAvAAI03I0u68b7WMFkQc="
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/wordwrap/-/wordwrap-1.0.0.tgz",
+ "integrity": "sha1-J1hIEIkUVqQXHI0CJkQa3pDLyus="
}
}
}
diff --git a/images/benchmarks/node/package.json b/images/benchmarks/node/package.json
index 7dcadd523..a46adc22a 100644
--- a/images/benchmarks/node/package.json
+++ b/images/benchmarks/node/package.json
@@ -11,7 +11,7 @@
"dependencies": {
"express": "^4.16.4",
"hbs": "^4.0.4",
- "redis": "^2.8.0",
+ "redis": "^3.1.2",
"redis-commands": "^1.2.0",
"redis-parser": "^2.6.0",
"secure-random-string": "^1.1.0"
diff --git a/nogo.yaml b/nogo.yaml
index 64332cb41..f26493fcb 100644
--- a/nogo.yaml
+++ b/nogo.yaml
@@ -179,12 +179,6 @@ analyzers:
internal:
exclude:
- pkg/gohacks/gohacks_unsafe.go # x ^ 0 always equals x.
- SA5011: # Possible nil pointer dereference.
- internal:
- exclude:
- # https://github.com/dominikh/go-tools/issues/924
- - pkg/sentry/fs/fdpipe/pipe_opener_test.go
- - pkg/tcpip/tests/integration/link_resolution_test.go
ST1019: # Multiple imports of the same package.
generated:
exclude:
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index 1e23850a9..67646f837 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -242,7 +242,7 @@ const (
// Statx represents struct statx.
//
-// +marshal
+// +marshal slice:StatxSlice
type Statx struct {
Mask uint32
Blksize uint32
@@ -270,6 +270,8 @@ type Statx struct {
var SizeOfStatx = (*Statx)(nil).SizeBytes()
// FileMode represents a mode_t.
+//
+// +marshal
type FileMode uint16
// Permissions returns just the permission bits.
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 29062c97a..4526d3f95 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -180,8 +180,12 @@ var (
// to get quote.
const SizeOfQuoteInputData = 64
-// SignReport is a struct that gets signed quote from input data.
+// SignReport is a struct that gets signed quote from input data. The
+// serialized quote is copied to buf.
+// size is an input that specifies the size of buf. When returned, it's updated
+// to the size of quote.
type SignReport struct {
- data [64]byte
- quote []byte
+ data [64]byte
+ size uint32
+ buf []byte
}
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 95871b8a5..f60e42997 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -542,6 +542,15 @@ type ControlMessageIPPacketInfo struct {
DestinationAddr InetAddr
}
+// ControlMessageIPv6PacketInfo represents struct in6_pktinfo from linux/ipv6.h.
+//
+// +marshal
+// +stateify savable
+type ControlMessageIPv6PacketInfo struct {
+ Addr Inet6Addr
+ NIC uint32
+}
+
// SizeOfControlMessageCredentials is the binary size of a
// ControlMessageCredentials struct.
var SizeOfControlMessageCredentials = (*ControlMessageCredentials)(nil).SizeBytes()
@@ -566,6 +575,10 @@ const SizeOfControlMessageTClass = 4
// control message.
const SizeOfControlMessageIPPacketInfo = 12
+// SizeOfControlMessageIPv6PacketInfo is the size of a
+// ControlMessageIPv6PacketInfo.
+const SizeOfControlMessageIPv6PacketInfo = 20
+
// SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call.
// From net/scm.h.
const SCM_MAX_FD = 253
diff --git a/pkg/atomicbitops/atomicbitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s
index 54c887ee5..6b9a67adc 100644
--- a/pkg/atomicbitops/atomicbitops_amd64.s
+++ b/pkg/atomicbitops/atomicbitops_amd64.s
@@ -16,28 +16,28 @@
#include "textflag.h"
-TEXT ·AndUint32(SB),$0-12
- MOVQ addr+0(FP), BP
+TEXT ·AndUint32(SB),NOSPLIT,$0-12
+ MOVQ addr+0(FP), BX
MOVL val+8(FP), AX
LOCK
- ANDL AX, 0(BP)
+ ANDL AX, 0(BX)
RET
-TEXT ·OrUint32(SB),$0-12
- MOVQ addr+0(FP), BP
+TEXT ·OrUint32(SB),NOSPLIT,$0-12
+ MOVQ addr+0(FP), BX
MOVL val+8(FP), AX
LOCK
- ORL AX, 0(BP)
+ ORL AX, 0(BX)
RET
-TEXT ·XorUint32(SB),$0-12
- MOVQ addr+0(FP), BP
+TEXT ·XorUint32(SB),NOSPLIT,$0-12
+ MOVQ addr+0(FP), BX
MOVL val+8(FP), AX
LOCK
- XORL AX, 0(BP)
+ XORL AX, 0(BX)
RET
-TEXT ·CompareAndSwapUint32(SB),$0-20
+TEXT ·CompareAndSwapUint32(SB),NOSPLIT,$0-20
MOVQ addr+0(FP), DI
MOVL old+8(FP), AX
MOVL new+12(FP), DX
@@ -46,28 +46,28 @@ TEXT ·CompareAndSwapUint32(SB),$0-20
MOVL AX, ret+16(FP)
RET
-TEXT ·AndUint64(SB),$0-16
- MOVQ addr+0(FP), BP
+TEXT ·AndUint64(SB),NOSPLIT,$0-16
+ MOVQ addr+0(FP), BX
MOVQ val+8(FP), AX
LOCK
- ANDQ AX, 0(BP)
+ ANDQ AX, 0(BX)
RET
-TEXT ·OrUint64(SB),$0-16
- MOVQ addr+0(FP), BP
+TEXT ·OrUint64(SB),NOSPLIT,$0-16
+ MOVQ addr+0(FP), BX
MOVQ val+8(FP), AX
LOCK
- ORQ AX, 0(BP)
+ ORQ AX, 0(BX)
RET
-TEXT ·XorUint64(SB),$0-16
- MOVQ addr+0(FP), BP
+TEXT ·XorUint64(SB),NOSPLIT,$0-16
+ MOVQ addr+0(FP), BX
MOVQ val+8(FP), AX
LOCK
- XORQ AX, 0(BP)
+ XORQ AX, 0(BX)
RET
-TEXT ·CompareAndSwapUint64(SB),$0-32
+TEXT ·CompareAndSwapUint64(SB),NOSPLIT,$0-32
MOVQ addr+0(FP), DI
MOVQ old+8(FP), AX
MOVQ new+16(FP), DX
diff --git a/pkg/atomicbitops/atomicbitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s
index 5c780851b..644a6bca5 100644
--- a/pkg/atomicbitops/atomicbitops_arm64.s
+++ b/pkg/atomicbitops/atomicbitops_arm64.s
@@ -16,7 +16,7 @@
#include "textflag.h"
-TEXT ·AndUint32(SB),$0-12
+TEXT ·AndUint32(SB),NOSPLIT,$0-12
MOVD ptr+0(FP), R0
MOVW val+8(FP), R1
again:
@@ -26,7 +26,7 @@ again:
CBNZ R3, again
RET
-TEXT ·OrUint32(SB),$0-12
+TEXT ·OrUint32(SB),NOSPLIT,$0-12
MOVD ptr+0(FP), R0
MOVW val+8(FP), R1
again:
@@ -36,7 +36,7 @@ again:
CBNZ R3, again
RET
-TEXT ·XorUint32(SB),$0-12
+TEXT ·XorUint32(SB),NOSPLIT,$0-12
MOVD ptr+0(FP), R0
MOVW val+8(FP), R1
again:
@@ -46,7 +46,7 @@ again:
CBNZ R3, again
RET
-TEXT ·CompareAndSwapUint32(SB),$0-20
+TEXT ·CompareAndSwapUint32(SB),NOSPLIT,$0-20
MOVD addr+0(FP), R0
MOVW old+8(FP), R1
MOVW new+12(FP), R2
@@ -60,7 +60,7 @@ done:
MOVW R3, prev+16(FP)
RET
-TEXT ·AndUint64(SB),$0-16
+TEXT ·AndUint64(SB),NOSPLIT,$0-16
MOVD ptr+0(FP), R0
MOVD val+8(FP), R1
again:
@@ -70,7 +70,7 @@ again:
CBNZ R3, again
RET
-TEXT ·OrUint64(SB),$0-16
+TEXT ·OrUint64(SB),NOSPLIT,$0-16
MOVD ptr+0(FP), R0
MOVD val+8(FP), R1
again:
@@ -80,7 +80,7 @@ again:
CBNZ R3, again
RET
-TEXT ·XorUint64(SB),$0-16
+TEXT ·XorUint64(SB),NOSPLIT,$0-16
MOVD ptr+0(FP), R0
MOVD val+8(FP), R1
again:
@@ -90,7 +90,7 @@ again:
CBNZ R3, again
RET
-TEXT ·CompareAndSwapUint64(SB),$0-32
+TEXT ·CompareAndSwapUint64(SB),NOSPLIT,$0-32
MOVD addr+0(FP), R0
MOVD old+8(FP), R1
MOVD new+16(FP), R2
diff --git a/pkg/atomicbitops/atomicbitops_noasm.go b/pkg/atomicbitops/atomicbitops_noasm.go
index 474c0c815..af6b1362e 100644
--- a/pkg/atomicbitops/atomicbitops_noasm.go
+++ b/pkg/atomicbitops/atomicbitops_noasm.go
@@ -21,6 +21,7 @@ import (
"sync/atomic"
)
+//go:nosplit
func AndUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -31,6 +32,7 @@ func AndUint32(addr *uint32, val uint32) {
}
}
+//go:nosplit
func OrUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -41,6 +43,7 @@ func OrUint32(addr *uint32, val uint32) {
}
}
+//go:nosplit
func XorUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -51,6 +54,7 @@ func XorUint32(addr *uint32, val uint32) {
}
}
+//go:nosplit
func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
for {
prev = atomic.LoadUint32(addr)
@@ -63,6 +67,7 @@ func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
}
}
+//go:nosplit
func AndUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -73,6 +78,7 @@ func AndUint64(addr *uint64, val uint64) {
}
}
+//go:nosplit
func OrUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -83,6 +89,7 @@ func OrUint64(addr *uint64, val uint64) {
}
}
+//go:nosplit
func XorUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -93,6 +100,7 @@ func XorUint64(addr *uint64, val uint64) {
}
}
+//go:nosplit
func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
for {
prev = atomic.LoadUint64(addr)
diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go
index 796efa240..59784eacb 100644
--- a/pkg/buffer/view_test.go
+++ b/pkg/buffer/view_test.go
@@ -509,6 +509,24 @@ func TestView(t *testing.T) {
}
}
+func TestViewClone(t *testing.T) {
+ const (
+ originalSize = 90
+ bytesToDelete = 30
+ )
+ var v View
+ v.AppendOwned(bytes.Repeat([]byte{originalSize}, originalSize))
+
+ clonedV := v.Clone()
+ v.TrimFront(bytesToDelete)
+ if got, want := int(v.Size()), originalSize-bytesToDelete; got != want {
+ t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got)
+ }
+ if got := clonedV.Size(); got != originalSize {
+ t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got)
+ }
+}
+
func TestViewPullUp(t *testing.T) {
for _, tc := range []struct {
desc string
diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go
index 69e867386..28eba2ff6 100644
--- a/pkg/crypto/crypto_stdlib.go
+++ b/pkg/crypto/crypto_stdlib.go
@@ -19,14 +19,21 @@ package crypto
import (
"crypto/ecdsa"
+ "crypto/elliptic"
"crypto/sha512"
+ "fmt"
"math/big"
)
-// EcdsaVerify verifies the signature in r, s of hash using ECDSA and the
-// public key, pub. Its return value records whether the signature is valid.
-func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) (bool, error) {
- return ecdsa.Verify(pub, hash, r, s), nil
+// EcdsaP384Sha384Verify verifies the signature in r, s of hash using ECDSA
+// P384 + SHA 384 and the public key, pub. Its return value records whether
+// the signature is valid.
+func EcdsaP384Sha384Verify(pub *ecdsa.PublicKey, data []byte, r, s *big.Int) (bool, error) {
+ if pub.Curve != elliptic.P384() {
+ return false, fmt.Errorf("unsupported key curve: want P-384, got %v", pub.Curve)
+ }
+ digest := sha512.Sum384(data)
+ return ecdsa.Verify(pub, digest[:], r, s), nil
}
// SumSha384 returns the SHA384 checksum of the data.
diff --git a/pkg/eventfd/BUILD b/pkg/eventfd/BUILD
new file mode 100644
index 000000000..02407cb99
--- /dev/null
+++ b/pkg/eventfd/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "eventfd",
+ srcs = [
+ "eventfd.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/hostarch",
+ "//pkg/tcpip/link/rawfile",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "eventfd_test",
+ srcs = ["eventfd_test.go"],
+ library = ":eventfd",
+)
diff --git a/pkg/eventfd/eventfd.go b/pkg/eventfd/eventfd.go
new file mode 100644
index 000000000..acdac01b8
--- /dev/null
+++ b/pkg/eventfd/eventfd.go
@@ -0,0 +1,115 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package eventfd wraps Linux's eventfd(2) syscall.
+package eventfd
+
+import (
+ "fmt"
+ "io"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+)
+
+const sizeofUint64 = 8
+
+// Eventfd represents a Linux eventfd object.
+type Eventfd struct {
+ fd int
+}
+
+// Create returns an initialized eventfd.
+func Create() (Eventfd, error) {
+ fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0)
+ if err != 0 {
+ return Eventfd{}, fmt.Errorf("failed to create eventfd: %v", error(err))
+ }
+ if err := unix.SetNonblock(int(fd), true); err != nil {
+ unix.Close(int(fd))
+ return Eventfd{}, err
+ }
+ return Eventfd{int(fd)}, nil
+}
+
+// Wrap returns an initialized Eventfd using the provided fd.
+func Wrap(fd int) Eventfd {
+ return Eventfd{fd}
+}
+
+// Close closes the eventfd, after which it should not be used.
+func (ev Eventfd) Close() error {
+ return unix.Close(ev.fd)
+}
+
+// Dup copies the eventfd, calling dup(2) on the underlying file descriptor.
+func (ev Eventfd) Dup() (Eventfd, error) {
+ other, err := unix.Dup(ev.fd)
+ if err != nil {
+ return Eventfd{}, fmt.Errorf("failed to dup: %v", other)
+ }
+ return Eventfd{other}, nil
+}
+
+// Notify alerts other users of the eventfd. Users can receive alerts by
+// calling Wait or Read.
+func (ev Eventfd) Notify() error {
+ return ev.Write(1)
+}
+
+// Write writes a specific value to the eventfd.
+func (ev Eventfd) Write(val uint64) error {
+ var buf [sizeofUint64]byte
+ hostarch.ByteOrder.PutUint64(buf[:], val)
+ for {
+ n, err := unix.Write(ev.fd, buf[:])
+ if err == unix.EINTR {
+ continue
+ }
+ if n != sizeofUint64 {
+ panic(fmt.Sprintf("short write to eventfd: got %d bytes, wanted %d", n, sizeofUint64))
+ }
+ return err
+ }
+}
+
+// Wait blocks until eventfd is non-zero (i.e. someone calls Notify or Write).
+func (ev Eventfd) Wait() error {
+ _, err := ev.Read()
+ return err
+}
+
+// Read blocks until eventfd is non-zero (i.e. someone calls Notify or Write)
+// and returns the value read.
+func (ev Eventfd) Read() (uint64, error) {
+ var tmp [sizeofUint64]byte
+ n, err := rawfile.BlockingReadUntranslated(ev.fd, tmp[:])
+ if err != 0 {
+ return 0, err
+ }
+ if n == 0 {
+ return 0, io.EOF
+ }
+ if n != sizeofUint64 {
+ panic(fmt.Sprintf("short read from eventfd: got %d bytes, wanted %d", n, sizeofUint64))
+ }
+ return hostarch.ByteOrder.Uint64(tmp[:]), nil
+}
+
+// FD returns the underlying file descriptor. Use with care, as this breaks the
+// Eventfd abstraction.
+func (ev Eventfd) FD() int {
+ return ev.fd
+}
diff --git a/pkg/eventfd/eventfd_test.go b/pkg/eventfd/eventfd_test.go
new file mode 100644
index 000000000..96998d530
--- /dev/null
+++ b/pkg/eventfd/eventfd_test.go
@@ -0,0 +1,75 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventfd
+
+import (
+ "testing"
+ "time"
+)
+
+func TestReadWrite(t *testing.T) {
+ efd, err := Create()
+ if err != nil {
+ t.Fatalf("failed to Create(): %v", err)
+ }
+ defer efd.Close()
+
+ // Make sure we can read actual values
+ const want = 343
+ if err := efd.Write(want); err != nil {
+ t.Fatalf("failed to write value: %d", want)
+ }
+
+ got, err := efd.Read()
+ if err != nil {
+ t.Fatalf("failed to read value: %v", err)
+ }
+ if got != want {
+ t.Fatalf("Read(): got %d, but wanted %d", got, want)
+ }
+}
+
+func TestWait(t *testing.T) {
+ efd, err := Create()
+ if err != nil {
+ t.Fatalf("failed to Create(): %v", err)
+ }
+ defer efd.Close()
+
+ // There's no way to test with certainty that Wait() blocks indefinitely, but
+ // as a best-effort we can wait a bit on it.
+ errCh := make(chan error)
+ go func() {
+ errCh <- efd.Wait()
+ }()
+ select {
+ case err := <-errCh:
+ t.Fatalf("Wait() returned without a call to Notify(): %v", err)
+ case <-time.After(500 * time.Millisecond):
+ }
+
+ // Notify and check that Wait() returned.
+ if err := efd.Notify(); err != nil {
+ t.Fatalf("Notify() failed: %v", err)
+ }
+ select {
+ case err := <-errCh:
+ if err != nil {
+ t.Fatalf("Read() failed: %v", err)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Read() did not return after Notify()")
+ }
+}
diff --git a/pkg/lisafs/BUILD b/pkg/lisafs/BUILD
new file mode 100644
index 000000000..313c1756d
--- /dev/null
+++ b/pkg/lisafs/BUILD
@@ -0,0 +1,117 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_template_instance(
+ name = "control_fd_refs",
+ out = "control_fd_refs.go",
+ package = "lisafs",
+ prefix = "controlFD",
+ template = "//pkg/refsvfs2:refs_template",
+ types = {
+ "T": "ControlFD",
+ },
+)
+
+go_template_instance(
+ name = "open_fd_refs",
+ out = "open_fd_refs.go",
+ package = "lisafs",
+ prefix = "openFD",
+ template = "//pkg/refsvfs2:refs_template",
+ types = {
+ "T": "OpenFD",
+ },
+)
+
+go_template_instance(
+ name = "control_fd_list",
+ out = "control_fd_list.go",
+ package = "lisafs",
+ prefix = "controlFD",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*ControlFD",
+ "Linker": "*ControlFD",
+ },
+)
+
+go_template_instance(
+ name = "open_fd_list",
+ out = "open_fd_list.go",
+ package = "lisafs",
+ prefix = "openFD",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*OpenFD",
+ "Linker": "*OpenFD",
+ },
+)
+
+go_library(
+ name = "lisafs",
+ srcs = [
+ "channel.go",
+ "client.go",
+ "client_file.go",
+ "communicator.go",
+ "connection.go",
+ "control_fd_list.go",
+ "control_fd_refs.go",
+ "fd.go",
+ "handlers.go",
+ "lisafs.go",
+ "message.go",
+ "open_fd_list.go",
+ "open_fd_refs.go",
+ "sample_message.go",
+ "server.go",
+ "sock.go",
+ ],
+ marshal = True,
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/cleanup",
+ "//pkg/context",
+ "//pkg/fdchannel",
+ "//pkg/flipcall",
+ "//pkg/fspath",
+ "//pkg/hostarch",
+ "//pkg/log",
+ "//pkg/marshal/primitive",
+ "//pkg/p9",
+ "//pkg/refsvfs2",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "sock_test",
+ size = "small",
+ srcs = ["sock_test.go"],
+ library = ":lisafs",
+ deps = [
+ "//pkg/marshal",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "connection_test",
+ size = "small",
+ srcs = ["connection_test.go"],
+ deps = [
+ ":lisafs",
+ "//pkg/sync",
+ "//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md
index 51d0d40e5..6b857321a 100644
--- a/pkg/lisafs/README.md
+++ b/pkg/lisafs/README.md
@@ -1,5 +1,8 @@
# Replacing 9P
+NOTE: LISAFS is **NOT** production ready. There are still some security concerns
+that must be resolved first.
+
## Background
The Linux filesystem model consists of the following key aspects (modulo mounts,
diff --git a/pkg/lisafs/channel.go b/pkg/lisafs/channel.go
new file mode 100644
index 000000000..301212e51
--- /dev/null
+++ b/pkg/lisafs/channel.go
@@ -0,0 +1,190 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "math"
+ "runtime"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+var (
+ chanHeaderLen = uint32((*channelHeader)(nil).SizeBytes())
+)
+
+// maxChannels returns the number of channels a client can create.
+//
+// The server will reject channel creation requests beyond this (per client).
+// Note that we don't want the number of channels to be too large, because each
+// accounts for a large region of shared memory.
+// TODO(gvisor.dev/issue/6313): Tune the number of channels.
+func maxChannels() int {
+ maxChans := runtime.GOMAXPROCS(0)
+ if maxChans < 2 {
+ maxChans = 2
+ }
+ if maxChans > 4 {
+ maxChans = 4
+ }
+ return maxChans
+}
+
+// channel implements Communicator and represents the communication endpoint
+// for the client and server and is used to perform fast IPC. Apart from
+// communicating data, a channel is also capable of donating file descriptors.
+type channel struct {
+ fdTracker
+ dead bool
+ data flipcall.Endpoint
+ fdChan fdchannel.Endpoint
+}
+
+var _ Communicator = (*channel)(nil)
+
+// PayloadBuf implements Communicator.PayloadBuf.
+func (ch *channel) PayloadBuf(size uint32) []byte {
+ return ch.data.Data()[chanHeaderLen : chanHeaderLen+size]
+}
+
+// SndRcvMessage implements Communicator.SndRcvMessage.
+func (ch *channel) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) {
+ // Write header. Requests can not donate FDs.
+ ch.marshalHdr(m, 0 /* numFDs */)
+
+ // One-shot communication. RPCs are expected to be quick rather than block.
+ rcvDataLen, err := ch.data.SendRecvFast(chanHeaderLen + payloadLen)
+ if err != nil {
+ // This channel is now unusable.
+ ch.dead = true
+ // Map the transport errors to EIO, but also log the real error.
+ log.Warningf("lisafs.sndRcvMessage: flipcall.Endpoint.SendRecv: %v", err)
+ return 0, 0, unix.EIO
+ }
+
+ return ch.rcvMsg(rcvDataLen)
+}
+
+func (ch *channel) shutdown() {
+ ch.data.Shutdown()
+}
+
+func (ch *channel) destroy() {
+ ch.dead = true
+ ch.fdChan.Destroy()
+ ch.data.Destroy()
+}
+
+// createChannel creates a server side channel. It returns a packet window
+// descriptor (for the data channel) and an open socket for the FD channel.
+func (c *Connection) createChannel(maxMessageSize uint32) (*channel, flipcall.PacketWindowDescriptor, int, error) {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ // If c.channels is nil, the connection has closed.
+ if c.channels == nil || len(c.channels) >= maxChannels() {
+ return nil, flipcall.PacketWindowDescriptor{}, -1, unix.ENOSYS
+ }
+ ch := &channel{}
+
+ // Set up data channel.
+ desc, err := c.channelAlloc.Allocate(flipcall.PacketHeaderBytes + int(chanHeaderLen+maxMessageSize))
+ if err != nil {
+ return nil, flipcall.PacketWindowDescriptor{}, -1, err
+ }
+ if err := ch.data.Init(flipcall.ServerSide, desc); err != nil {
+ return nil, flipcall.PacketWindowDescriptor{}, -1, err
+ }
+
+ // Set up FD channel.
+ fdSocks, err := fdchannel.NewConnectedSockets()
+ if err != nil {
+ ch.data.Destroy()
+ return nil, flipcall.PacketWindowDescriptor{}, -1, err
+ }
+ ch.fdChan.Init(fdSocks[0])
+ clientFDSock := fdSocks[1]
+
+ c.channels = append(c.channels, ch)
+ return ch, desc, clientFDSock, nil
+}
+
+// sendFDs sends as many FDs as it can. The failure to send an FD does not
+// cause an error and fail the entire RPC. FDs are considered supplementary
+// responses that are not critical to the RPC response itself. The failure to
+// send the (i)th FD will cause all the following FDs to not be sent as well
+// because the order in which FDs are donated is important.
+func (ch *channel) sendFDs(fds []int) uint8 {
+ numFDs := len(fds)
+ if numFDs == 0 {
+ return 0
+ }
+
+ if numFDs > math.MaxUint8 {
+ log.Warningf("dropping all FDs because too many FDs to donate: %v", numFDs)
+ return 0
+ }
+
+ for i, fd := range fds {
+ if err := ch.fdChan.SendFD(fd); err != nil {
+ log.Warningf("error occurred while sending (%d/%d)th FD on channel(%p): %v", i+1, numFDs, ch, err)
+ return uint8(i)
+ }
+ }
+ return uint8(numFDs)
+}
+
+// channelHeader is the header present in front of each message received on
+// flipcall endpoint when the protocol version being used is 1.
+//
+// +marshal
+type channelHeader struct {
+ message MID
+ numFDs uint8
+ _ uint8 // Need to make struct packed.
+}
+
+func (ch *channel) marshalHdr(m MID, numFDs uint8) {
+ header := &channelHeader{
+ message: m,
+ numFDs: numFDs,
+ }
+ header.MarshalUnsafe(ch.data.Data())
+}
+
+func (ch *channel) rcvMsg(dataLen uint32) (MID, uint32, error) {
+ if dataLen < chanHeaderLen {
+ log.Warningf("received data has size smaller than header length: %d", dataLen)
+ return 0, 0, unix.EIO
+ }
+
+ // Read header first.
+ var header channelHeader
+ header.UnmarshalUnsafe(ch.data.Data())
+
+ // Read any FDs.
+ for i := 0; i < int(header.numFDs); i++ {
+ fd, err := ch.fdChan.RecvFDNonblock()
+ if err != nil {
+ log.Warningf("expected %d FDs, received %d successfully, got err after that: %v", header.numFDs, i, err)
+ break
+ }
+ ch.TrackFD(fd)
+ }
+
+ return header.message, dataLen - chanHeaderLen, nil
+}
diff --git a/pkg/lisafs/client.go b/pkg/lisafs/client.go
new file mode 100644
index 000000000..ccf1b9f72
--- /dev/null
+++ b/pkg/lisafs/client.go
@@ -0,0 +1,432 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "fmt"
+ "math"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
const (
	// fdsToCloseBatchSize is the number of closed FDs batched before a Close
	// RPC is made to close them all. fdsToCloseBatchSize is immutable.
	fdsToCloseBatchSize = 100
)
+
// Client helps manage a connection to the lisafs server and pass messages
// efficiently. There is a 1:1 mapping between a Connection and a Client.
type Client struct {
	// sockComm is the main socket by which this connection is established.
	// Communication over the socket is synchronized by sockMu.
	sockMu   sync.Mutex
	sockComm *sockCommunicator

	// channelsMu protects channels and availableChannels.
	channelsMu sync.Mutex
	// channels tracks all the channels.
	channels []*channel
	// availableChannels is a LIFO (stack) of channels available to be used.
	availableChannels []*channel
	// activeWg represents active channels.
	activeWg sync.WaitGroup

	// watchdogWg only holds the watchdog goroutine.
	watchdogWg sync.WaitGroup

	// supported caches information about which messages are supported. It is
	// indexed by MID. An MID is supported if supported[MID] is true.
	supported []bool

	// maxMessageSize is the maximum payload length (in bytes) that can be sent.
	// It is initialized on Mount and is immutable.
	maxMessageSize uint32

	// fdsMu protects fdsToClose.
	fdsMu sync.Mutex
	// fdsToClose tracks the FDs to close. It caches the FDs no longer being used
	// by the client and closes them in one shot. It is not preserved across
	// checkpoint/restore as FDIDs are not preserved.
	fdsToClose []FDID
}
+
// NewClient creates a new client for communication with the server. It mounts
// the server and creates channels for fast IPC. NewClient takes ownership over
// the passed socket. On success, it returns the initialized client along with
// the root Inode.
func NewClient(sock *unet.Socket, mountPath string) (*Client, *Inode, error) {
	maxChans := maxChannels()
	c := &Client{
		sockComm:          newSockComm(sock),
		channels:          make([]*channel, 0, maxChans),
		availableChannels: make([]*channel, 0, maxChans),
		maxMessageSize:    1 << 20, // 1 MB for now.
		fdsToClose:        make([]FDID, 0, fdsToCloseBatchSize),
	}

	// Start a goroutine to check socket health. This goroutine is also
	// responsible for client cleanup.
	c.watchdogWg.Add(1)
	go c.watchdog()

	// Clean everything up if anything fails.
	cu := cleanup.Make(func() {
		c.Close()
	})
	defer cu.Clean()

	// Mount the server first. Assume Mount is supported so that we can make the
	// Mount RPC below. c.supported is replaced with the server-reported set
	// once the response arrives.
	c.supported = make([]bool, Mount+1)
	c.supported[Mount] = true
	mountMsg := MountReq{
		MountPath: SizedString(mountPath),
	}
	var mountResp MountResp
	if err := c.SndRcvMessage(Mount, uint32(mountMsg.SizeBytes()), mountMsg.MarshalBytes, mountResp.UnmarshalBytes, nil); err != nil {
		return nil, nil, err
	}

	// Initialize client from the Mount response: the negotiated maximum
	// message size and the set of messages the server supports.
	c.maxMessageSize = uint32(mountResp.MaxMessageSize)
	var maxSuppMID MID
	for _, suppMID := range mountResp.SupportedMs {
		if suppMID > maxSuppMID {
			maxSuppMID = suppMID
		}
	}
	c.supported = make([]bool, maxSuppMID+1)
	for _, suppMID := range mountResp.SupportedMs {
		c.supported[suppMID] = true
	}

	// Create channels in parallel so that channels can be used to create more
	// channels and costly initialization like flipcall.Endpoint.Connect can
	// proceed in parallel.
	var channelsWg sync.WaitGroup
	channelErrs := make([]error, maxChans)
	for i := 0; i < maxChans; i++ {
		channelsWg.Add(1)
		curChanID := i
		go func() {
			defer channelsWg.Done()
			ch, err := c.createChannel()
			if err != nil {
				log.Warningf("channel creation failed: %v", err)
				channelErrs[curChanID] = err
				return
			}
			c.channelsMu.Lock()
			c.channels = append(c.channels, ch)
			c.availableChannels = append(c.availableChannels, ch)
			c.channelsMu.Unlock()
		}()
	}
	channelsWg.Wait()

	for _, channelErr := range channelErrs {
		// Return the first non-nil channel creation error.
		if channelErr != nil {
			return nil, nil, channelErr
		}
	}
	cu.Release()

	return c, &mountResp.Root, nil
}
+
// watchdog blocks until the main socket is shut down by either endpoint and
// then tears the client down: it shuts down active channels, waits for
// in-flight requests to drain, destroys all channels, and finally closes the
// main socket.
func (c *Client) watchdog() {
	defer c.watchdogWg.Done()

	events := []unix.PollFd{
		{
			Fd:     int32(c.sockComm.FD()),
			Events: unix.POLLHUP | unix.POLLRDHUP,
		},
	}

	// Wait for a shutdown event. EINTR/EAGAIN are transient; retry the poll.
	for {
		n, err := unix.Ppoll(events, nil, nil)
		if err == unix.EINTR || err == unix.EAGAIN {
			continue
		}
		if err != nil {
			log.Warningf("lisafs.Client.watch(): %v", err)
		} else if n != 1 {
			log.Warningf("lisafs.Client.watch(): got %d events, wanted 1", n)
		}
		break
	}

	// Shutdown all active channels and wait for them to complete.
	c.shutdownActiveChans()
	c.activeWg.Wait()

	// Close all channels.
	c.channelsMu.Lock()
	for _, ch := range c.channels {
		ch.destroy()
	}
	c.channelsMu.Unlock()

	// Close main socket.
	c.sockComm.destroy()
}
+
+func (c *Client) shutdownActiveChans() {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+
+ availableChans := make(map[*channel]bool)
+ for _, ch := range c.availableChannels {
+ availableChans[ch] = true
+ }
+ for _, ch := range c.channels {
+ // A channel that is not available is active.
+ if _, ok := availableChans[ch]; !ok {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.shutdown()
+ }
+ }
+
+ // Prevent channels from becoming available and serving new requests.
+ c.availableChannels = nil
+}
+
// Close shuts down the main socket and waits for the watchdog to clean up.
// It is safe to call even after the connection has already died.
func (c *Client) Close() {
	// This shutdown has no effect if the watchdog has already fired and closed
	// the main socket.
	c.sockComm.shutdown()
	c.watchdogWg.Wait()
}
+
// createChannel makes the Channel RPC to set up a new fast-IPC channel. The
// server donates two FDs: fds[0] backs the flipcall data window and fds[1] is
// the FD-passing socket. On success the returned channel is connected and
// ready for use.
func (c *Client) createChannel() (*channel, error) {
	var chanResp ChannelResp
	var fds [2]int
	if err := c.SndRcvMessage(Channel, 0, NoopMarshal, chanResp.UnmarshalUnsafe, fds[:]); err != nil {
		return nil, err
	}
	if fds[0] < 0 || fds[1] < 0 {
		closeFDs(fds[:])
		return nil, fmt.Errorf("insufficient FDs provided in Channel response: %v", fds)
	}

	// Let's create the channel. Note that only fds[0] is closed on return;
	// ownership of fds[1] is transferred to fdChan below.
	defer closeFDs(fds[:1]) // The data FD is not needed after this.
	desc := flipcall.PacketWindowDescriptor{
		FD:     fds[0],
		Offset: chanResp.dataOffset,
		Length: int(chanResp.dataLength),
	}

	ch := &channel{}
	if err := ch.data.Init(flipcall.ClientSide, desc); err != nil {
		closeFDs(fds[1:])
		return nil, err
	}
	ch.fdChan.Init(fds[1]) // fdChan now owns this FD.

	// Only a connected channel is usable.
	if err := ch.data.Connect(); err != nil {
		ch.destroy()
		return nil, err
	}
	return ch, nil
}
+
+// IsSupported returns true if this connection supports the passed message.
+func (c *Client) IsSupported(m MID) bool {
+ return int(m) < len(c.supported) && c.supported[m]
+}
+
// CloseFDBatched either queues the passed FD to be closed or makes a batch
// RPC to close all the accumulated FDs-to-close. Errors from the batch Close
// RPC are logged and dropped; there is nothing useful the caller can do.
func (c *Client) CloseFDBatched(ctx context.Context, fd FDID) {
	c.fdsMu.Lock()
	c.fdsToClose = append(c.fdsToClose, fd)
	if len(c.fdsToClose) < fdsToCloseBatchSize {
		c.fdsMu.Unlock()
		return
	}

	// Flush the cache. We should not hold fdsMu while making an RPC, so be sure
	// to copy the fdsToClose to another buffer before unlocking fdsMu.
	var toCloseArr [fdsToCloseBatchSize]FDID
	toClose := toCloseArr[:len(c.fdsToClose)]
	copy(toClose, c.fdsToClose)

	// Clear fdsToClose so other FDIDs can be appended.
	c.fdsToClose = c.fdsToClose[:0]
	c.fdsMu.Unlock()

	req := CloseReq{FDs: toClose}
	ctx.UninterruptibleSleepStart(false)
	err := c.SndRcvMessage(Close, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
	ctx.UninterruptibleSleepFinish(false)
	if err != nil {
		log.Warningf("lisafs: batch closing FDs returned error: %v", err)
	}
}
+
// SyncFDs makes a Fsync RPC to sync multiple FDs in one round trip. It is a
// no-op for an empty fds slice. The task is put in uninterruptible sleep for
// the duration of the RPC.
func (c *Client) SyncFDs(ctx context.Context, fds []FDID) error {
	if len(fds) == 0 {
		return nil
	}
	req := FsyncReq{FDs: fds}
	ctx.UninterruptibleSleepStart(false)
	err := c.SndRcvMessage(FSync, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
	ctx.UninterruptibleSleepFinish(false)
	return err
}
+
+// SndRcvMessage invokes reqMarshal to marshal the request onto the payload
+// buffer, wakes up the server to process the request, waits for the response
+// and invokes respUnmarshal with the response payload. respFDs is populated
+// with the received FDs, extra fields are set to -1.
+//
+// Note that the function arguments intentionally accept marshal.Marshallable
+// functions like Marshal{Bytes/Unsafe} and Unmarshal{Bytes/Unsafe} instead of
+// directly accepting the marshal.Marshallable interface. Even though just
+// accepting marshal.Marshallable is cleaner, it leads to a heap allocation
+// (even if that interface variable itself does not escape). In other words,
+// implicit conversion to an interface leads to an allocation.
+//
+// Precondition: reqMarshal and respUnmarshal must be non-nil.
+func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal func(dst []byte), respUnmarshal func(src []byte), respFDs []int) error {
+ if !c.IsSupported(m) {
+ return unix.EOPNOTSUPP
+ }
+ if payloadLen > c.maxMessageSize {
+ log.Warningf("message %d has message size = %d which is larger than client.maxMessageSize = %d", m, payloadLen, c.maxMessageSize)
+ return unix.EIO
+ }
+ wantFDs := len(respFDs)
+ if wantFDs > math.MaxUint8 {
+ log.Warningf("want too many FDs: %d", wantFDs)
+ return unix.EINVAL
+ }
+
+ // Acquire a communicator.
+ comm := c.acquireCommunicator()
+ defer c.releaseCommunicator(comm)
+
+ // Marshal the request into comm's payload buffer and make the RPC.
+ reqMarshal(comm.PayloadBuf(payloadLen))
+ respM, respPayloadLen, err := comm.SndRcvMessage(m, payloadLen, uint8(wantFDs))
+
+ // Handle FD donation.
+ rcvFDs := comm.ReleaseFDs()
+ if numRcvFDs := len(rcvFDs); numRcvFDs+wantFDs > 0 {
+ // releasedFDs is memory owned by comm which can not be returned to caller.
+ // Copy it into the caller's buffer.
+ numFDCopied := copy(respFDs, rcvFDs)
+ if numFDCopied < numRcvFDs {
+ log.Warningf("%d unexpected FDs were donated by the server, wanted", numRcvFDs-numFDCopied, wantFDs)
+ closeFDs(rcvFDs[numFDCopied:])
+ }
+ if numFDCopied < wantFDs {
+ for i := numFDCopied; i < wantFDs; i++ {
+ respFDs[i] = -1
+ }
+ }
+ }
+
+ // Error cases.
+ if err != nil {
+ closeFDs(respFDs)
+ return err
+ }
+ if respM == Error {
+ closeFDs(respFDs)
+ var resp ErrorResp
+ resp.UnmarshalUnsafe(comm.PayloadBuf(respPayloadLen))
+ return unix.Errno(resp.errno)
+ }
+ if respM != m {
+ closeFDs(respFDs)
+ log.Warningf("sent %d message but got %d in response", m, respM)
+ return unix.EINVAL
+ }
+
+ // Success. The payload must be unmarshalled *before* comm is released.
+ respUnmarshal(comm.PayloadBuf(respPayloadLen))
+ return nil
+}
+
// acquireCommunicator returns a communicator to make an RPC on. If no channel
// is available it falls back to the main socket, leaving sockMu locked until
// releaseCommunicator is called.
//
// Postcondition: releaseCommunicator() must be called on the returned value.
func (c *Client) acquireCommunicator() Communicator {
	// Prefer using channel over socket because:
	// - Channel uses a shared memory region for passing messages. IO from shared
	//   memory is faster and does not involve making a syscall.
	// - No intermediate buffer allocation needed. With a channel, the message
	//   can be directly pasted into the shared memory region.
	if ch := c.getChannel(); ch != nil {
		return ch
	}

	c.sockMu.Lock()
	return c.sockComm
}
+
// releaseCommunicator returns comm to the pool it came from: it unlocks sockMu
// for the socket communicator, or pushes a channel back onto the available
// stack.
//
// Precondition: comm must have been acquired via acquireCommunicator().
func (c *Client) releaseCommunicator(comm Communicator) {
	switch t := comm.(type) {
	case *sockCommunicator:
		c.sockMu.Unlock() // +checklocksforce: locked in acquireCommunicator().
	case *channel:
		c.releaseChannel(t)
	default:
		panic(fmt.Sprintf("unknown communicator type %T", t))
	}
}
+
+// getChannel pops a channel from the available channels stack. The caller must
+// release the channel after use.
+func (c *Client) getChannel() *channel {
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ if len(c.availableChannels) == 0 {
+ return nil
+ }
+
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ c.activeWg.Add(1)
+ return ch
+}
+
// releaseChannel pushes the passed channel back onto the available channel
// stack, unless the channel has died or the client is shutting down, and
// marks it as no longer active.
func (c *Client) releaseChannel(ch *channel) {
	c.channelsMu.Lock()
	defer c.channelsMu.Unlock()

	// If availableChannels is nil, then watchdog has fired and the client is
	// shutting down. So don't make this channel available again.
	if !ch.dead && c.availableChannels != nil {
		c.availableChannels = append(c.availableChannels, ch)
	}
	c.activeWg.Done()
}
diff --git a/pkg/lisafs/client_file.go b/pkg/lisafs/client_file.go
new file mode 100644
index 000000000..170c15705
--- /dev/null
+++ b/pkg/lisafs/client_file.go
@@ -0,0 +1,528 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "fmt"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
// ClientFD is a wrapper around FDID that provides client-side utilities
// so that RPC making is easier.
type ClientFD struct {
	// fd is the server-side identifier this wrapper operates on.
	fd FDID
	// client is the connection over which RPCs for fd are made.
	client *Client
}
+
// ID returns the underlying FDID.
func (f *ClientFD) ID() FDID {
	return f.fd
}
+
// Client returns the backing Client.
func (f *ClientFD) Client() *Client {
	return f.client
}
+
// NewFD initializes a new ClientFD that wraps fd and makes its RPCs over c.
func (c *Client) NewFD(fd FDID) ClientFD {
	return ClientFD{
		client: c,
		fd:     fd,
	}
}
+
// Ok returns true if the underlying FD is ok.
func (f *ClientFD) Ok() bool {
	return f.fd.Ok()
}
+
// CloseBatched queues this FD to be closed on the server and resets f.fd.
// This may invoke the Close RPC if the queue is full.
func (f *ClientFD) CloseBatched(ctx context.Context) {
	f.client.CloseFDBatched(ctx, f.fd)
	f.fd = InvalidFDID
}
+
// Close closes this FD immediately (invoking a Close RPC). Consider using
// CloseBatched if closing this FD on remote right away is not critical.
func (f *ClientFD) Close(ctx context.Context) error {
	// A stack array avoids heap-allocating the one-element FD slice.
	fdArr := [1]FDID{f.fd}
	req := CloseReq{FDs: fdArr[:]}

	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(Close, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
	ctx.UninterruptibleSleepFinish(false)
	return err
}
+
// OpenAt makes the OpenAt RPC on f with the given flags. It returns the new
// FDID and the donated host FD (or -1 if none was donated).
func (f *ClientFD) OpenAt(ctx context.Context, flags uint32) (FDID, int, error) {
	req := OpenAtReq{
		FD:    f.fd,
		Flags: flags,
	}
	var respFD [1]int
	var resp OpenAtResp
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(OpenAt, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalUnsafe, respFD[:])
	ctx.UninterruptibleSleepFinish(false)
	return resp.NewFD, respFD[0], err
}
+
// OpenCreateAt makes the OpenCreateAt RPC with f as the directory FD. It
// returns the created child's Inode, the new FDID and the donated host FD
// (or -1 if none was donated).
func (f *ClientFD) OpenCreateAt(ctx context.Context, name string, flags uint32, mode linux.FileMode, uid UID, gid GID) (Inode, FDID, int, error) {
	var req OpenCreateAtReq
	req.DirFD = f.fd
	req.Name = SizedString(name)
	req.Flags = primitive.Uint32(flags)
	req.Mode = mode
	req.UID = uid
	req.GID = gid

	var respFD [1]int
	var resp OpenCreateAtResp
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(OpenCreateAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, respFD[:])
	ctx.UninterruptibleSleepFinish(false)
	return resp.Child, resp.NewFD, respFD[0], err
}
+
// StatTo makes the Fstat RPC and populates stat with the result. The response
// is unmarshalled directly into the caller-provided stat, avoiding a copy.
func (f *ClientFD) StatTo(ctx context.Context, stat *linux.Statx) error {
	req := StatReq{FD: f.fd}
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(FStat, uint32(req.SizeBytes()), req.MarshalUnsafe, stat.UnmarshalUnsafe, nil)
	ctx.UninterruptibleSleepFinish(false)
	return err
}
+
// Sync makes the Fsync RPC for just this FD. To sync many FDs in a single
// round trip, use Client.SyncFDs instead.
func (f *ClientFD) Sync(ctx context.Context) error {
	req := FsyncReq{FDs: []FDID{f.fd}}
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(FSync, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
	ctx.UninterruptibleSleepFinish(false)
	return err
}
+
// chunkify applies fn to buf in chunks of at most chunkSize bytes. fn receives
// the next chunk along with the running offset and reports how many bytes it
// processed and any error. chunkify stops at the first error or the first
// short (partial) chunk and returns the total number of bytes processed.
func chunkify(chunkSize uint64, buf []byte, fn func([]byte, uint64) (uint64, error)) (uint64, error) {
	toProcess := uint64(len(buf))
	var (
		totalProcessed uint64
		curProcessed   uint64
		off            uint64
		err            error
	)
	for {
		if totalProcessed == toProcess {
			return totalProcessed, nil
		}

		if totalProcessed+chunkSize > toProcess {
			// Final, possibly short, chunk.
			curProcessed, err = fn(buf[totalProcessed:], off)
		} else {
			curProcessed, err = fn(buf[totalProcessed:totalProcessed+chunkSize], off)
		}
		totalProcessed += curProcessed
		off += curProcessed

		if err != nil {
			return totalProcessed, err
		}

		// Return partial result immediately.
		if curProcessed < chunkSize {
			return totalProcessed, nil
		}

		// If we received more bytes than we ever requested, this is a problem.
		// (The panic message previously had an unbalanced parenthesis.)
		if totalProcessed > toProcess {
			panic(fmt.Sprintf("bytes completed (%d) > requested (%d)", totalProcessed, toProcess))
		}
	}
}
+
// Read makes the PRead RPC, reading into dst starting at offset. Reads larger
// than one message are split into multiple RPCs via chunkify.
func (f *ClientFD) Read(ctx context.Context, dst []byte, offset uint64) (uint64, error) {
	var resp PReadResp
	// maxDataReadSize represents the maximum amount of data we can read at once
	// (maximum message size - metadata size present in resp). Uninitialized
	// resp.SizeBytes() correctly returns the metadata size only (since the read
	// buffer is empty).
	maxDataReadSize := uint64(f.client.maxMessageSize) - uint64(resp.SizeBytes())
	return chunkify(maxDataReadSize, dst, func(buf []byte, curOff uint64) (uint64, error) {
		req := PReadReq{
			Offset: offset + curOff,
			FD:     f.fd,
			Count:  uint32(len(buf)),
		}

		// This will be unmarshalled into. Already set Buf so that we don't need to
		// allocate a temporary buffer during unmarshalling.
		// PReadResp.UnmarshalBytes expects this to be set.
		resp.Buf = buf
		ctx.UninterruptibleSleepStart(false)
		err := f.client.SndRcvMessage(PRead, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
		ctx.UninterruptibleSleepFinish(false)
		return uint64(resp.NumBytes), err
	})
}
+
// Write makes the PWrite RPC, writing src starting at offset. Writes larger
// than one message are split into multiple RPCs via chunkify.
func (f *ClientFD) Write(ctx context.Context, src []byte, offset uint64) (uint64, error) {
	var req PWriteReq
	// maxDataWriteSize represents the maximum amount of data we can write at
	// once (maximum message size - metadata size present in req). Uninitialized
	// req.SizeBytes() correctly returns the metadata size only (since the write
	// buffer is empty).
	maxDataWriteSize := uint64(f.client.maxMessageSize) - uint64(req.SizeBytes())
	return chunkify(maxDataWriteSize, src, func(buf []byte, curOff uint64) (uint64, error) {
		req = PWriteReq{
			Offset:   primitive.Uint64(offset + curOff),
			FD:       f.fd,
			NumBytes: primitive.Uint32(len(buf)),
			Buf:      buf,
		}

		var resp PWriteResp
		ctx.UninterruptibleSleepStart(false)
		err := f.client.SndRcvMessage(PWrite, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
		ctx.UninterruptibleSleepFinish(false)
		return resp.Count, err
	})
}
+
// MkdirAt makes the MkdirAt RPC with f as the directory FD and returns the
// new directory's Inode.
func (f *ClientFD) MkdirAt(ctx context.Context, name string, mode linux.FileMode, uid UID, gid GID) (*Inode, error) {
	var req MkdirAtReq
	req.DirFD = f.fd
	req.Name = SizedString(name)
	req.Mode = mode
	req.UID = uid
	req.GID = gid

	var resp MkdirAtResp
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(MkdirAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
	ctx.UninterruptibleSleepFinish(false)
	return &resp.ChildDir, err
}
+
// SymlinkAt makes the SymlinkAt RPC with f as the directory FD and returns
// the new symlink's Inode.
func (f *ClientFD) SymlinkAt(ctx context.Context, name, target string, uid UID, gid GID) (*Inode, error) {
	req := SymlinkAtReq{
		DirFD:  f.fd,
		Name:   SizedString(name),
		Target: SizedString(target),
		UID:    uid,
		GID:    gid,
	}

	var resp SymlinkAtResp
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(SymlinkAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
	ctx.UninterruptibleSleepFinish(false)
	return &resp.Symlink, err
}
+
// LinkAt makes the LinkAt RPC with f as the directory FD and targetFD as the
// file to link; it returns the new link's Inode.
func (f *ClientFD) LinkAt(ctx context.Context, targetFD FDID, name string) (*Inode, error) {
	req := LinkAtReq{
		DirFD:  f.fd,
		Target: targetFD,
		Name:   SizedString(name),
	}

	var resp LinkAtResp
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(LinkAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
	ctx.UninterruptibleSleepFinish(false)
	return &resp.Link, err
}
+
// MknodAt makes the MknodAt RPC with f as the directory FD and returns the
// new node's Inode.
//
// NOTE(review): the parameter order is (minor, major), the reverse of the
// usual (major, minor) convention — callers must take care; confirm against
// the server-side handler.
func (f *ClientFD) MknodAt(ctx context.Context, name string, mode linux.FileMode, uid UID, gid GID, minor, major uint32) (*Inode, error) {
	var req MknodAtReq
	req.DirFD = f.fd
	req.Name = SizedString(name)
	req.Mode = mode
	req.UID = uid
	req.GID = gid
	req.Minor = primitive.Uint32(minor)
	req.Major = primitive.Uint32(major)

	var resp MknodAtResp
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(MknodAt, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalUnsafe, nil)
	ctx.UninterruptibleSleepFinish(false)
	return &resp.Child, err
}
+
// SetStat makes the SetStat RPC. It returns the server-reported failure mask
// and failure errno, along with the RPC error itself.
//
// NOTE(review): unix.Errno(resp.FailureErrNo) is a non-nil error value even
// when FailureErrNo is 0; callers presumably must consult FailureMask before
// acting on the second return value — confirm against callers.
func (f *ClientFD) SetStat(ctx context.Context, stat *linux.Statx) (uint32, error, error) {
	req := SetStatReq{
		FD:   f.fd,
		Mask: stat.Mask,
		Mode: uint32(stat.Mode),
		UID:  UID(stat.UID),
		GID:  GID(stat.GID),
		Size: stat.Size,
		Atime: linux.Timespec{
			Sec:  stat.Atime.Sec,
			Nsec: int64(stat.Atime.Nsec),
		},
		Mtime: linux.Timespec{
			Sec:  stat.Mtime.Sec,
			Nsec: int64(stat.Mtime.Nsec),
		},
	}

	var resp SetStatResp
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(SetStat, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalUnsafe, nil)
	ctx.UninterruptibleSleepFinish(false)
	return resp.FailureMask, unix.Errno(resp.FailureErrNo), err
}
+
// WalkMultiple makes the Walk RPC with multiple path components, returning
// the walk status and the Inodes for the components that were reached.
func (f *ClientFD) WalkMultiple(ctx context.Context, names []string) (WalkStatus, []Inode, error) {
	req := WalkReq{
		DirFD: f.fd,
		Path:  StringArray(names),
	}

	var resp WalkResp
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(Walk, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil)
	ctx.UninterruptibleSleepFinish(false)
	return resp.Status, resp.Inodes, err
}
+
// Walk makes the Walk RPC with just one path component to walk. It translates
// non-success walk statuses into errno-style errors and guards against a
// malformed response (wrong number of Inodes).
func (f *ClientFD) Walk(ctx context.Context, name string) (*Inode, error) {
	req := WalkReq{
		DirFD: f.fd,
		Path:  []string{name},
	}

	// Provide a stack-backed one-element result slice to avoid an allocation
	// during unmarshalling.
	var inode [1]Inode
	resp := WalkResp{Inodes: inode[:]}
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(Walk, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil)
	ctx.UninterruptibleSleepFinish(false)
	if err != nil {
		return nil, err
	}

	switch resp.Status {
	case WalkComponentDoesNotExist:
		return nil, unix.ENOENT
	case WalkComponentSymlink:
		// f is not a directory which can be walked on.
		return nil, unix.ENOTDIR
	}

	if n := len(resp.Inodes); n > 1 {
		// Close the returned control FDs so they do not leak on the server.
		for i := range resp.Inodes {
			f.client.CloseFDBatched(ctx, resp.Inodes[i].ControlFD)
		}
		log.Warningf("requested to walk one component, but got %d results", n)
		return nil, unix.EIO
	} else if n == 0 {
		log.Warningf("walk has success status but no results returned")
		return nil, unix.ENOENT
	}
	return &inode[0], err
}
+
// WalkStat makes the WalkStat RPC with multiple path components to walk and
// returns the stat results for the walked components.
func (f *ClientFD) WalkStat(ctx context.Context, names []string) ([]linux.Statx, error) {
	req := WalkReq{
		DirFD: f.fd,
		Path:  StringArray(names),
	}

	var resp WalkStatResp
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(WalkStat, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil)
	ctx.UninterruptibleSleepFinish(false)
	return resp.Stats, err
}
+
// StatFSTo makes the FStatFS RPC and populates statFS with the result. The
// response is unmarshalled directly into the caller-provided statFS.
func (f *ClientFD) StatFSTo(ctx context.Context, statFS *StatFS) error {
	req := FStatFSReq{FD: f.fd}
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(FStatFS, uint32(req.SizeBytes()), req.MarshalUnsafe, statFS.UnmarshalUnsafe, nil)
	ctx.UninterruptibleSleepFinish(false)
	return err
}
+
// Allocate makes the FAllocate RPC with the given mode, offset and length.
func (f *ClientFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
	req := FAllocateReq{
		FD:     f.fd,
		Mode:   mode,
		Offset: offset,
		Length: length,
	}
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(FAllocate, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, nil)
	ctx.UninterruptibleSleepFinish(false)
	return err
}
+
// ReadLinkAt makes the ReadLinkAt RPC and returns the symlink target.
func (f *ClientFD) ReadLinkAt(ctx context.Context) (string, error) {
	req := ReadLinkAtReq{FD: f.fd}
	var resp ReadLinkAtResp
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(ReadLinkAt, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
	ctx.UninterruptibleSleepFinish(false)
	return string(resp.Target), err
}
+
// Flush makes the Flush RPC. It is a silent no-op if the server does not
// support the Flush message.
func (f *ClientFD) Flush(ctx context.Context) error {
	if !f.client.IsSupported(Flush) {
		// If Flush is not supported, it probably means that it would be a noop.
		return nil
	}
	req := FlushReq{FD: f.fd}
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(Flush, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, nil)
	ctx.UninterruptibleSleepFinish(false)
	return err
}
+
// Connect makes the Connect RPC and returns the donated socket FD. If the RPC
// succeeds but no FD was donated (slot left at -1), EBADF is returned.
func (f *ClientFD) Connect(ctx context.Context, sockType linux.SockType) (int, error) {
	req := ConnectReq{FD: f.fd, SockType: uint32(sockType)}
	var sockFD [1]int
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(Connect, uint32(req.SizeBytes()), req.MarshalUnsafe, NoopUnmarshal, sockFD[:])
	ctx.UninterruptibleSleepFinish(false)
	if err == nil && sockFD[0] < 0 {
		err = unix.EBADF
	}
	return sockFD[0], err
}
+
// UnlinkAt makes the UnlinkAt RPC with f as the directory FD.
func (f *ClientFD) UnlinkAt(ctx context.Context, name string, flags uint32) error {
	req := UnlinkAtReq{
		DirFD: f.fd,
		Name:  SizedString(name),
		Flags: primitive.Uint32(flags),
	}

	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(UnlinkAt, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
	ctx.UninterruptibleSleepFinish(false)
	return err
}
+
// RenameTo makes the RenameAt RPC which renames f to newDirFD directory with
// name newName.
func (f *ClientFD) RenameTo(ctx context.Context, newDirFD FDID, newName string) error {
	req := RenameAtReq{
		Renamed: f.fd,
		NewDir:  newDirFD,
		NewName: SizedString(newName),
	}

	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(RenameAt, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
	ctx.UninterruptibleSleepFinish(false)
	return err
}
+
// Getdents64 makes the Getdents64 RPC with f as the directory FD and returns
// the resulting dirents.
func (f *ClientFD) Getdents64(ctx context.Context, count int32) ([]Dirent64, error) {
	req := Getdents64Req{
		DirFD: f.fd,
		Count: count,
	}

	var resp Getdents64Resp
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(Getdents64, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
	ctx.UninterruptibleSleepFinish(false)
	return resp.Dirents, err
}
+
// ListXattr makes the FListXattr RPC and returns the extended attribute names.
func (f *ClientFD) ListXattr(ctx context.Context, size uint64) ([]string, error) {
	req := FListXattrReq{
		FD:   f.fd,
		Size: size,
	}

	var resp FListXattrResp
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(FListXattr, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.UnmarshalBytes, nil)
	ctx.UninterruptibleSleepFinish(false)
	return resp.Xattrs, err
}
+
// GetXattr makes the FGetXattr RPC and returns the value of the named
// extended attribute; size bounds the response buffer.
func (f *ClientFD) GetXattr(ctx context.Context, name string, size uint64) (string, error) {
	req := FGetXattrReq{
		FD:      f.fd,
		Name:    SizedString(name),
		BufSize: primitive.Uint32(size),
	}

	var resp FGetXattrResp
	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(FGetXattr, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil)
	ctx.UninterruptibleSleepFinish(false)
	return string(resp.Value), err
}
+
// SetXattr makes the FSetXattr RPC to set the named extended attribute.
func (f *ClientFD) SetXattr(ctx context.Context, name string, value string, flags uint32) error {
	req := FSetXattrReq{
		FD:    f.fd,
		Name:  SizedString(name),
		Value: SizedString(value),
		Flags: primitive.Uint32(flags),
	}

	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(FSetXattr, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
	ctx.UninterruptibleSleepFinish(false)
	return err
}
+
// RemoveXattr makes the FRemoveXattr RPC to remove the named extended
// attribute.
func (f *ClientFD) RemoveXattr(ctx context.Context, name string) error {
	req := FRemoveXattrReq{
		FD:   f.fd,
		Name: SizedString(name),
	}

	ctx.UninterruptibleSleepStart(false)
	err := f.client.SndRcvMessage(FRemoveXattr, uint32(req.SizeBytes()), req.MarshalBytes, NoopUnmarshal, nil)
	ctx.UninterruptibleSleepFinish(false)
	return err
}
diff --git a/pkg/lisafs/communicator.go b/pkg/lisafs/communicator.go
new file mode 100644
index 000000000..ec2035158
--- /dev/null
+++ b/pkg/lisafs/communicator.go
@@ -0,0 +1,80 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import "golang.org/x/sys/unix"
+
+// Communicator is a server side utility which represents exactly how the
+// server is communicating with the client.
+type Communicator interface {
+	// PayloadBuf returns a slice to the payload section of its internal buffer
+	// where the message can be marshalled. The handlers should use this to
+	// populate the payload buffer with the message.
+	//
+	// The payload buffer contents *should* be preserved across calls with
+	// different sizes. Note that this is not a guarantee, because a compromised
+	// owner of a "shared" payload buffer can tamper with its contents anytime,
+	// even when it's not its turn to do so.
+	PayloadBuf(size uint32) []byte
+
+	// SndRcvMessage sends message m. The caller must have populated PayloadBuf()
+	// with payloadLen bytes. The caller expects to receive wantFDs FDs.
+	// Any received FDs must be accessible via ReleaseFDs(). It returns the
+	// response message along with the response payload length.
+	SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error)
+
+	// DonateFD makes fd non-blocking and starts tracking it. The next call to
+	// ReleaseFDs will include fd in the order it was added. Communicator takes
+	// ownership of fd. Server side should call this.
+	DonateFD(fd int) error
+
+	// TrackFD starts tracking fd. The next call to ReleaseFDs will include fd
+	// in the order it was added. Communicator takes ownership of fd. Client
+	// side should use this for accumulating received FDs.
+	TrackFD(fd int)
+
+	// ReleaseFDs returns the accumulated FDs and stops tracking them. The
+	// ownership of the FDs is transferred to the caller.
+	ReleaseFDs() []int
+}
+
+// fdTracker is a partial implementation of Communicator. It can be embedded in
+// Communicator implementations to keep track of FD donations.
+//
+// fdTracker is not synchronized; NOTE(review): it presumably relies on each
+// communicator being driven by one goroutine at a time — confirm.
+type fdTracker struct {
+	// fds accumulates tracked FDs in donation order.
+	fds []int
+}
+
+// DonateFD implements Communicator.DonateFD.
+//
+// Ownership of fd transfers to the tracker on entry: on failure the FD is
+// closed, so the caller must not use fd after this call in either case.
+func (d *fdTracker) DonateFD(fd int) error {
+	// Make sure the FD is non-blocking.
+	if err := unix.SetNonblock(fd, true); err != nil {
+		unix.Close(fd)
+		return err
+	}
+	d.TrackFD(fd)
+	return nil
+}
+
+// TrackFD implements Communicator.TrackFD.
+func (d *fdTracker) TrackFD(fd int) {
+	d.fds = append(d.fds, fd)
+}
+
+// ReleaseFDs implements Communicator.ReleaseFDs.
+//
+// NOTE(review): the returned slice aliases the tracker's internal buffer —
+// d.fds is truncated, not detached, so later TrackFD calls may overwrite the
+// returned elements. Callers must consume the FDs before tracking more.
+func (d *fdTracker) ReleaseFDs() []int {
+	ret := d.fds
+	d.fds = d.fds[:0]
+	return ret
+}
diff --git a/pkg/lisafs/connection.go b/pkg/lisafs/connection.go
new file mode 100644
index 000000000..f6e5ecb4f
--- /dev/null
+++ b/pkg/lisafs/connection.go
@@ -0,0 +1,320 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Connection represents a connection between a mount point in the client and a
+// mount point in the server. It is owned by the server on which it was started
+// and facilitates communication with the client mount.
+//
+// Each connection is set up using a unix domain socket. One end is owned by
+// the server and the other end is owned by the client. The connection may
+// spawn additional communication channels for the same mount for increased
+// RPC concurrency.
+type Connection struct {
+	// server is the server on which this connection was created. It is immutably
+	// associated with it for its entire lifetime.
+	server *Server
+
+	// mounted is a one way flag indicating whether this connection has been
+	// mounted correctly and the server is initialized properly.
+	mounted bool
+
+	// readonly indicates if this connection is readonly. All write operations
+	// will fail with EROFS.
+	readonly bool
+
+	// sockComm is the main socket by which this connection is established.
+	sockComm *sockCommunicator
+
+	// channelsMu protects channels.
+	channelsMu sync.Mutex
+	// channels keeps track of all open channels.
+	channels []*channel
+
+	// activeWg represents active channels.
+	activeWg sync.WaitGroup
+
+	// reqGate counts requests that are still being handled.
+	reqGate sync.Gate
+
+	// channelAlloc is used to allocate memory for channels.
+	channelAlloc *flipcall.PacketWindowAllocator
+
+	fdsMu sync.RWMutex
+	// fds keeps track of open FDs on this server. It is protected by fdsMu.
+	fds map[FDID]genericFD
+	// nextFDID is the next available FDID. It is protected by fdsMu.
+	nextFDID FDID
+}
+
+// CreateConnection initializes a new connection on the server, communicating
+// over sock. It does not create a server; it must be called on an existing
+// one. The connection must be started separately (see Connection.Run).
+func (s *Server) CreateConnection(sock *unet.Socket, readonly bool) (*Connection, error) {
+	c := &Connection{
+		sockComm: newSockComm(sock),
+		server:   s,
+		readonly: readonly,
+		channels: make([]*channel, 0, maxChannels()),
+		fds:      make(map[FDID]genericFD),
+		// FDID 0 is reserved as InvalidFDID, so start allocating above it.
+		nextFDID: InvalidFDID + 1,
+	}
+
+	alloc, err := flipcall.NewPacketWindowAllocator()
+	if err != nil {
+		return nil, err
+	}
+	c.channelAlloc = alloc
+	return c, nil
+}
+
+// Server returns the server that owns this connection.
+func (c *Connection) Server() *Server {
+	return c.server
+}
+
+// ServerImpl returns the implementation backing the server that owns this
+// connection.
+func (c *Connection) ServerImpl() ServerImpl {
+	return c.server.impl
+}
+
+// Run defines the lifecycle of a connection. It serves RPCs arriving on the
+// main socket until a read or write fails, then tears the connection down
+// via close().
+func (c *Connection) Run() {
+	defer c.close()
+
+	// Start handling requests on this connection.
+	for {
+		m, payloadLen, err := c.sockComm.rcvMsg(0 /* wantFDs */)
+		if err != nil {
+			log.Debugf("sock read failed, closing connection: %v", err)
+			return
+		}
+
+		// Response FDs (donated by the handler) are closed once they have
+		// been sent to the client; the client holds its own copies.
+		respM, respPayloadLen, respFDs := c.handleMsg(c.sockComm, m, payloadLen)
+		err = c.sockComm.sndPrepopulatedMsg(respM, respPayloadLen, respFDs)
+		closeFDs(respFDs)
+		if err != nil {
+			log.Debugf("sock write failed, closing connection: %v", err)
+			return
+		}
+	}
+}
+
+// service starts servicing the passed channel until the channel is shutdown.
+// This is a blocking method and hence must be called in a separate goroutine.
+//
+// A received data length of 0 terminates the loop cleanly (NOTE(review):
+// presumably indicates the peer shut the flipcall endpoint down — confirm
+// against flipcall semantics).
+func (c *Connection) service(ch *channel) error {
+	rcvDataLen, err := ch.data.RecvFirst()
+	if err != nil {
+		return err
+	}
+	for rcvDataLen > 0 {
+		m, payloadLen, err := ch.rcvMsg(rcvDataLen)
+		if err != nil {
+			return err
+		}
+		// FDs are sent over the channel's FD socket, not the data window;
+		// our copies are closed once sent.
+		respM, respPayloadLen, respFDs := c.handleMsg(ch, m, payloadLen)
+		numFDs := ch.sendFDs(respFDs)
+		closeFDs(respFDs)
+
+		ch.marshalHdr(respM, numFDs)
+		rcvDataLen, err = ch.data.SendRecv(respPayloadLen + chanHeaderLen)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// respondError marshals an ErrorResp carrying err into comm's payload buffer
+// and returns the (Error, length, no-FDs) triple to send back to the client.
+func (c *Connection) respondError(comm Communicator, err unix.Errno) (MID, uint32, []int) {
+	resp := &ErrorResp{errno: uint32(err)}
+	respLen := uint32(resp.SizeBytes())
+	resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+	return Error, respLen, nil
+}
+
+// handleMsg dispatches message m to its registered handler and returns the
+// response message ID, the response payload length, and any FDs the handler
+// donated for the client. Failures are converted to an Error response.
+func (c *Connection) handleMsg(comm Communicator, m MID, payloadLen uint32) (MID, uint32, []int) {
+	if !c.reqGate.Enter() {
+		// c.close() has been called; the connection is shutting down.
+		return c.respondError(comm, unix.ECONNRESET)
+	}
+	defer c.reqGate.Leave()
+
+	// Mount must be the first successful RPC on a connection.
+	if !c.mounted && m != Mount {
+		log.Warningf("connection must first be mounted")
+		return c.respondError(comm, unix.EINVAL)
+	}
+
+	// Check if the message is supported for forward compatibility.
+	if int(m) >= len(c.server.handlers) || c.server.handlers[m] == nil {
+		log.Warningf("received request which is not supported by the server, MID = %d", m)
+		return c.respondError(comm, unix.EOPNOTSUPP)
+	}
+
+	// Try handling the request.
+	respPayloadLen, err := c.server.handlers[m](c, comm, payloadLen)
+	// FDs must be released (and, on failure, closed) here: an Error response
+	// carries no FDs, so they would otherwise leak.
+	fds := comm.ReleaseFDs()
+	if err != nil {
+		closeFDs(fds)
+		return c.respondError(comm, p9.ExtractErrno(err))
+	}
+
+	return m, respPayloadLen, fds
+}
+
+// close tears the connection down: it drains inflight requests, shuts down
+// and destroys all channels along with their shared memory, closes the main
+// socket, and drops the connection's refs on all tracked FDs. It is invoked
+// exactly once, from Run's defer.
+func (c *Connection) close() {
+	// Wait for completion of all inflight requests. This is mostly so that if
+	// a request is stuck, the sandbox supervisor has the opportunity to kill
+	// us with SIGABRT to get a stack dump of the offending handler.
+	c.reqGate.Close()
+
+	// Shutdown and clean up channels.
+	c.channelsMu.Lock()
+	for _, ch := range c.channels {
+		ch.shutdown()
+	}
+	// Wait for the channel-servicing goroutines to drain before destroying
+	// the channels they are using.
+	c.activeWg.Wait()
+	for _, ch := range c.channels {
+		ch.destroy()
+	}
+	// This is to prevent additional channels from being created.
+	c.channels = nil
+	c.channelsMu.Unlock()
+
+	// Free the channel memory.
+	if c.channelAlloc != nil {
+		c.channelAlloc.Destroy()
+	}
+
+	// Ensure the connection is closed.
+	c.sockComm.destroy()
+
+	// Cleanup all FDs.
+	c.fdsMu.Lock()
+	for fdid := range c.fds {
+		fd := c.removeFDLocked(fdid)
+		fd.DecRef(nil) // Drop the ref held by c.
+	}
+	c.fdsMu.Unlock()
+}
+
+// lookupFD returns the FD tracked under id on this connection. The caller
+// gains a ref on the FD on success; it must be dropped when done.
+func (c *Connection) lookupFD(id FDID) (genericFD, error) {
+	c.fdsMu.RLock()
+	defer c.fdsMu.RUnlock()
+
+	fd, ok := c.fds[id]
+	if !ok {
+		return nil, unix.EBADF
+	}
+	fd.IncRef()
+	return fd, nil
+}
+
+// LookupControlFD retrieves the control FD identified by id on this
+// connection. On success, the caller gains a ref on the FD. EINVAL is
+// returned (and the ref dropped) if id refers to an FD of the wrong kind.
+func (c *Connection) LookupControlFD(id FDID) (*ControlFD, error) {
+	fd, err := c.lookupFD(id)
+	if err != nil {
+		return nil, err
+	}
+
+	cfd, ok := fd.(*ControlFD)
+	if !ok {
+		fd.DecRef(nil)
+		return nil, unix.EINVAL
+	}
+	return cfd, nil
+}
+
+// LookupOpenFD retrieves the open FD identified by id on this
+// connection. On success, the caller gains a ref on the FD. EINVAL is
+// returned (and the ref dropped) if id refers to an FD of the wrong kind.
+func (c *Connection) LookupOpenFD(id FDID) (*OpenFD, error) {
+	fd, err := c.lookupFD(id)
+	if err != nil {
+		return nil, err
+	}
+
+	ofd, ok := fd.(*OpenFD)
+	if !ok {
+		fd.DecRef(nil)
+		return nil, unix.EINVAL
+	}
+	return ofd, nil
+}
+
+// insertFD inserts the passed fd into the internal datastructure to track FDs.
+// The caller must hold a ref on fd which is transferred to the connection.
+//
+// FDIDs are allocated sequentially and never reused, so exhausting the
+// uint32 space is treated as a fatal invariant violation.
+func (c *Connection) insertFD(fd genericFD) FDID {
+	c.fdsMu.Lock()
+	defer c.fdsMu.Unlock()
+
+	res := c.nextFDID
+	c.nextFDID++
+	if c.nextFDID < res {
+		panic("ran out of FDIDs")
+	}
+	c.fds[res] = fd
+	return res
+}
+
+// RemoveFD makes c stop tracking the passed FDID and drops its ref on it.
+func (c *Connection) RemoveFD(id FDID) {
+	c.fdsMu.Lock()
+	fd := c.removeFDLocked(id)
+	c.fdsMu.Unlock()
+	if fd != nil {
+		// Drop the ref held by c. This can take arbitrarily long. So do not hold
+		// c.fdsMu while calling it.
+		fd.DecRef(nil)
+	}
+}
+
+// RemoveControlFDLocked is the same as RemoveFD with added preconditions.
+// The rename-mutex precondition allows it to use the cheaper DecRefLocked,
+// which does not re-acquire the rename mutex.
+//
+// Preconditions:
+// * server's rename mutex must at least be read locked.
+// * id must be pointing to a control FD.
+func (c *Connection) RemoveControlFDLocked(id FDID) {
+	c.fdsMu.Lock()
+	fd := c.removeFDLocked(id)
+	c.fdsMu.Unlock()
+	if fd != nil {
+		// Drop the ref held by c. This can take arbitrarily long. So do not hold
+		// c.fdsMu while calling it.
+		fd.(*ControlFD).DecRefLocked()
+	}
+}
+
+// removeFDLocked makes c stop tracking the passed FDID. Note that the caller
+// must drop ref on the returned fd (preferably without holding c.fdsMu).
+// Returns nil if id is not tracked, logging a warning.
+//
+// Precondition: c.fdsMu is locked.
+func (c *Connection) removeFDLocked(id FDID) genericFD {
+	fd := c.fds[id]
+	if fd == nil {
+		log.Warningf("removeFDLocked called on non-existent FDID %d", id)
+		return nil
+	}
+	delete(c.fds, id)
+	return fd
+}
diff --git a/pkg/lisafs/connection_test.go b/pkg/lisafs/connection_test.go
new file mode 100644
index 000000000..28ba47112
--- /dev/null
+++ b/pkg/lisafs/connection_test.go
@@ -0,0 +1,194 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package connection_test
+
+import (
+ "reflect"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/lisafs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+const (
+	// dynamicMsgID and versionMsgID are test-only message IDs registered
+	// just above the last predefined message (lisafs.Channel).
+	dynamicMsgID = lisafs.Channel + 1
+	versionMsgID = dynamicMsgID + 1
+)
+
+// handlers is the test server's dispatch table, indexed by message ID.
+var handlers = [...]lisafs.RPCHandler{
+	lisafs.Error:   lisafs.ErrorHandler,
+	lisafs.Mount:   lisafs.MountHandler,
+	lisafs.Channel: lisafs.ChannelHandler,
+	dynamicMsgID:   dynamicMsgHandler,
+	versionMsgID:   versionHandler,
+}
+
+// testServer implements lisafs.ServerImpl.
+type testServer struct {
+	lisafs.Server
+}
+
+var _ lisafs.ServerImpl = (*testServer)(nil)
+
+// testControlFD is a stub control FD implementation for tests; only FD() is
+// backed, the rest comes from the embedded (nil) ControlFDImpl.
+type testControlFD struct {
+	lisafs.ControlFD
+	lisafs.ControlFDImpl
+}
+
+// FD implements lisafs.ControlFDImpl.FD.
+func (fd *testControlFD) FD() *lisafs.ControlFD {
+	return &fd.ControlFD
+}
+
+// Mount implements lisafs.ServerImpl.Mount.
+func (s *testServer) Mount(c *lisafs.Connection, mountPath string) (lisafs.ControlFDImpl, lisafs.Inode, error) {
+	return &testControlFD{}, lisafs.Inode{ControlFD: 1}, nil
+}
+
+// MaxMessageSize implements lisafs.ServerImpl.MaxMessageSize.
+func (s *testServer) MaxMessageSize() uint32 {
+	return lisafs.MaxMessageSize()
+}
+
+// SupportedMessages implements lisafs.ServerImpl.SupportedMessages.
+func (s *testServer) SupportedMessages() []lisafs.MID {
+	return []lisafs.MID{
+		lisafs.Mount,
+		lisafs.Channel,
+		dynamicMsgID,
+		versionMsgID,
+	}
+}
+
+// runServerClient creates a lisafs server/client pair connected over a
+// socketpair, runs clientFn against the client, and then tears both sides
+// down (closing the client triggers server connection shutdown).
+func runServerClient(t testing.TB, clientFn func(c *lisafs.Client)) {
+	serverSocket, clientSocket, err := unet.SocketPair(false)
+	if err != nil {
+		t.Fatalf("socketpair got err %v expected nil", err)
+	}
+
+	ts := &testServer{}
+	ts.Server.InitTestOnly(ts, handlers[:])
+	conn, err := ts.CreateConnection(serverSocket, false /* readonly */)
+	if err != nil {
+		// Fatalf calls FailNow/runtime.Goexit and never returns, so no
+		// explicit return is needed after it.
+		t.Fatalf("starting connection failed: %v", err)
+	}
+	ts.StartConnection(conn)
+
+	c, _, err := lisafs.NewClient(clientSocket, "/")
+	if err != nil {
+		t.Fatalf("client creation failed: %v", err)
+	}
+
+	clientFn(c)
+
+	c.Close() // This should trigger client and server shutdown.
+	ts.Wait()
+}
+
+// TestStartUp tests that the server and client can be started up correctly.
+func TestStartUp(t *testing.T) {
+	runServerClient(t, func(c *lisafs.Client) {
+		// Error is a server-to-client message only; the client must not
+		// advertise it as sendable.
+		if c.IsSupported(lisafs.Error) {
+			t.Errorf("sending error messages should not be supported")
+		}
+	})
+}
+
+// TestUnsupportedMessage checks that the server replies EOPNOTSUPP to a
+// message ID it has no handler registered for.
+func TestUnsupportedMessage(t *testing.T) {
+	// len(handlers) is one past the highest registered message ID.
+	unsupportedM := lisafs.MID(len(handlers))
+	runServerClient(t, func(c *lisafs.Client) {
+		if err := c.SndRcvMessage(unsupportedM, 0, lisafs.NoopMarshal, lisafs.NoopUnmarshal, nil); err != unix.EOPNOTSUPP {
+			t.Errorf("expected EOPNOTSUPP but got err: %v", err)
+		}
+	})
+}
+
+// dynamicMsgHandler echoes the received MsgDynamic back to the client.
+func dynamicMsgHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) {
+	var req lisafs.MsgDynamic
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	// Just echo back the message.
+	respPayloadLen := uint32(req.SizeBytes())
+	req.MarshalBytes(comm.PayloadBuf(respPayloadLen))
+	return respPayloadLen, nil
+}
+
+// TestStress stress tests sending many messages from various goroutines.
+func TestStress(t *testing.T) {
+	runServerClient(t, func(c *lisafs.Client) {
+		concurrency := 8
+		numMsgPerGoroutine := 5000
+		var clientWg sync.WaitGroup
+		for i := 0; i < concurrency; i++ {
+			clientWg.Add(1)
+			go func() {
+				defer clientWg.Done()
+
+				for j := 0; j < numMsgPerGoroutine; j++ {
+					// Create a massive random message.
+					var req lisafs.MsgDynamic
+					req.Randomize(100)
+
+					var resp lisafs.MsgDynamic
+					// t.Errorf (unlike t.Fatalf) is safe to call from
+					// spawned goroutines.
+					if err := c.SndRcvMessage(dynamicMsgID, uint32(req.SizeBytes()), req.MarshalBytes, resp.UnmarshalBytes, nil); err != nil {
+						t.Errorf("SndRcvMessage: received unexpected error %v", err)
+						return
+					}
+					if !reflect.DeepEqual(&req, &resp) {
+						t.Errorf("response should be the same as request: request = %+v, response = %+v", req, resp)
+					}
+				}
+			}()
+		}
+
+		clientWg.Wait()
+	})
+}
+
+// versionHandler responds to a P9Version request with its own version
+// response, mimicking a p9 Tversion/Rversion exchange for the benchmark
+// below.
+func versionHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) {
+	// Handlers should unmarshal into their own request object and build a
+	// fresh response rather than reusing shared variables.
+	var rv lisafs.P9Version
+	rv.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	// Create a new response.
+	sv := lisafs.P9Version{
+		MSize:   rv.MSize,
+		Version: "9P2000.L.Google.11",
+	}
+	respPayloadLen := uint32(sv.SizeBytes())
+	sv.MarshalBytes(comm.PayloadBuf(respPayloadLen))
+	return respPayloadLen, nil
+}
+
+// BenchmarkSendRecv exists to compete against p9's BenchmarkSendRecvChannel.
+func BenchmarkSendRecv(b *testing.B) {
+	b.ReportAllocs()
+	sendV := lisafs.P9Version{
+		MSize:   1 << 20,
+		Version: "9P2000.L.Google.12",
+	}
+
+	var recvV lisafs.P9Version
+	runServerClient(b, func(c *lisafs.Client) {
+		for i := 0; i < b.N; i++ {
+			if err := c.SndRcvMessage(versionMsgID, uint32(sendV.SizeBytes()), sendV.MarshalBytes, recvV.UnmarshalBytes, nil); err != nil {
+				b.Fatalf("unexpected error occurred: %v", err)
+			}
+		}
+	})
+}
diff --git a/pkg/lisafs/fd.go b/pkg/lisafs/fd.go
new file mode 100644
index 000000000..cc6919a1b
--- /dev/null
+++ b/pkg/lisafs/fd.go
@@ -0,0 +1,374 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// FDID (file descriptor identifier) is used to identify FDs on a connection.
+// Each connection has its own FDID namespace.
+//
+// +marshal slice:FDIDSlice
+type FDID uint32
+
+// InvalidFDID represents an invalid FDID. The zero value is reserved so that
+// an unset FDID is never mistaken for a real one.
+const InvalidFDID FDID = 0
+
+// Ok returns true if f is a valid FDID.
+func (f FDID) Ok() bool {
+	return f != InvalidFDID
+}
+
+// genericFD can represent a ControlFD or OpenFD. It is the common
+// reference-counted type stored in Connection.fds.
+type genericFD interface {
+	refsvfs2.RefCounter
+}
+
+// A ControlFD is the gateway to the backing filesystem tree node. It is an
+// unusual concept. This exists to provide a safe way to do path-based
+// operations on the file. It performs operations that can modify the
+// filesystem tree and synchronizes these operations. See ControlFDImpl for
+// supported operations.
+//
+// It is not an inode, because multiple control FDs are allowed to exist on the
+// same file. It is not a file descriptor because it is not tied to any access
+// mode, i.e. a control FD can change its access mode based on the operation
+// being performed.
+//
+// Reference Model:
+// * When a control FD is created, the connection takes a ref on it which
+//   represents the client's ref on the FD.
+// * The client can drop its ref via the Close RPC which will in turn make the
+//   connection drop its ref.
+// * Each control FD holds a ref on its parent for its entire life time.
+type ControlFD struct {
+	controlFDRefs
+	controlFDEntry
+
+	// parent is the parent directory FD containing the file this FD represents.
+	// A ControlFD holds a ref on parent for its entire lifetime. If this FD
+	// represents the root, then parent is nil. parent may be a control FD from
+	// another connection (another mount point). parent is protected by the
+	// backing server's rename mutex.
+	parent *ControlFD
+
+	// name is the file path's last component name. If this FD represents the
+	// root directory, then name is the mount path. name is protected by the
+	// backing server's rename mutex.
+	name string
+
+	// children is a linked list of all children control FDs. As per reference
+	// model, all children hold a ref on this FD.
+	// children is protected by childrenMu and server's rename mutex. To have
+	// mutual exclusion, it is sufficient to:
+	// * Hold rename mutex for reading and lock childrenMu. OR
+	// * Hold rename mutex for writing.
+	childrenMu sync.Mutex
+	children   controlFDList
+
+	// openFDs is a linked list of all FDs opened on this FD. As per reference
+	// model, all open FDs hold a ref on this FD.
+	openFDsMu sync.RWMutex
+	openFDs   openFDList
+
+	// All the following fields are immutable.
+
+	// id is the unique FD identifier which identifies this FD on its connection.
+	id FDID
+
+	// conn is the backing connection owning this FD.
+	conn *Connection
+
+	// ftype is the file type of the backing inode. ftype.FileType() == ftype.
+	ftype linux.FileMode
+
+	// impl is the control FD implementation which embeds this struct. It
+	// contains all the implementation specific details.
+	impl ControlFDImpl
+}
+
+var _ genericFD = (*ControlFD)(nil)
+
+// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context
+// parameter should never be used. It exists solely to comply with the
+// refsvfs2.RefCounter interface.
+//
+// On the final ref drop, the FD unlinks itself from its parent's children
+// list and closes its implementation.
+func (fd *ControlFD) DecRef(context.Context) {
+	fd.controlFDRefs.DecRef(func() {
+		if fd.parent != nil {
+			fd.conn.server.RenameMu.RLock()
+			fd.parent.childrenMu.Lock()
+			fd.parent.children.Remove(fd)
+			fd.parent.childrenMu.Unlock()
+			fd.conn.server.RenameMu.RUnlock()
+			// Drop the parent ref outside RenameMu: the parent's own release
+			// callback (this function) acquires RenameMu.RLock again.
+			fd.parent.DecRef(nil) // Drop the ref on the parent.
+		}
+		fd.impl.Close(fd.conn)
+	})
+}
+
+// DecRefLocked is the same as DecRef except the added precondition. Because
+// the rename mutex is already held, the parent unlink happens without
+// re-acquiring it.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) DecRefLocked() {
+	fd.controlFDRefs.DecRef(func() {
+		fd.clearParentLocked()
+		fd.impl.Close(fd.conn)
+	})
+}
+
+// clearParentLocked unlinks fd from its parent (if any) and drops the ref fd
+// held on it. No-op for the root FD.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) clearParentLocked() {
+	if fd.parent == nil {
+		return
+	}
+	fd.parent.childrenMu.Lock()
+	fd.parent.children.Remove(fd)
+	fd.parent.childrenMu.Unlock()
+	fd.parent.DecRefLocked() // Drop the ref on the parent.
+}
+
+// Init must be called before first use of fd. It inserts fd into the
+// filesystem tree.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) Init(c *Connection, parent *ControlFD, name string, mode linux.FileMode, impl ControlFDImpl) {
+	// Initialize fd with 1 ref which is transferred to c via c.insertFD().
+	fd.controlFDRefs.InitRefs()
+	fd.conn = c
+	fd.id = c.insertFD(fd)
+	fd.name = name
+	// Only the file type bits are retained; permission bits are not tracked
+	// on the control FD.
+	fd.ftype = mode.FileType()
+	fd.impl = impl
+	fd.setParentLocked(parent)
+}
+
+// setParentLocked links fd under parent (taking a ref on it). parent is nil
+// for the root FD.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) setParentLocked(parent *ControlFD) {
+	fd.parent = parent
+	if parent != nil {
+		parent.IncRef() // Hold a ref on parent.
+		parent.childrenMu.Lock()
+		parent.children.PushBack(fd)
+		parent.childrenMu.Unlock()
+	}
+}
+
+// FileType returns the file mode only containing the file type bits. ftype
+// is immutable, so this and the predicates below need no locking.
+func (fd *ControlFD) FileType() linux.FileMode {
+	return fd.ftype
+}
+
+// IsDir indicates whether fd represents a directory.
+func (fd *ControlFD) IsDir() bool {
+	return fd.ftype == unix.S_IFDIR
+}
+
+// IsRegular indicates whether fd represents a regular file.
+func (fd *ControlFD) IsRegular() bool {
+	return fd.ftype == unix.S_IFREG
+}
+
+// IsSymlink indicates whether fd represents a symbolic link.
+func (fd *ControlFD) IsSymlink() bool {
+	return fd.ftype == unix.S_IFLNK
+}
+
+// IsSocket indicates whether fd represents a socket.
+func (fd *ControlFD) IsSocket() bool {
+	return fd.ftype == unix.S_IFSOCK
+}
+
+// NameLocked returns the backing file's last component name.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) NameLocked() string {
+	return fd.name
+}
+
+// ParentLocked returns the parent control FD. Note that parent might be a
+// control FD from another connection on this server. So its ID must not be
+// returned on this connection because FDIDs are local to their connection.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) ParentLocked() ControlFDImpl {
+	if fd.parent == nil {
+		return nil
+	}
+	return fd.parent.impl
+}
+
+// ID returns fd's ID.
+func (fd *ControlFD) ID() FDID {
+	return fd.id
+}
+
+// FilePath returns the absolute path of the file fd was opened on. This is
+// expensive and must not be called on hot paths. FilePath acquires the rename
+// mutex for reading so callers should not be holding it.
+func (fd *ControlFD) FilePath() string {
+	// Lock the rename mutex for reading to ensure that the filesystem tree is not
+	// changed while we traverse it upwards.
+	fd.conn.server.RenameMu.RLock()
+	defer fd.conn.server.RenameMu.RUnlock()
+	return fd.FilePathLocked()
+}
+
+// FilePathLocked is the same as FilePath with the additional precondition.
+//
+// Precondition: server's rename mutex must be at least read locked.
+func (fd *ControlFD) FilePathLocked() string {
+	// Walk upwards and prepend name to res. The walk terminates at the root,
+	// whose parent is nil and whose name is the mount path.
+	var res fspath.Builder
+	for fd != nil {
+		res.PrependComponent(fd.name)
+		fd = fd.parent
+	}
+	return res.String()
+}
+
+// ForEachOpenFD executes fn on each FD opened on fd. The open-FD list is
+// read locked for the duration, so fn must not open or close FDs on fd.
+func (fd *ControlFD) ForEachOpenFD(fn func(ofd OpenFDImpl)) {
+	fd.openFDsMu.RLock()
+	defer fd.openFDsMu.RUnlock()
+	for ofd := fd.openFDs.Front(); ofd != nil; ofd = ofd.Next() {
+		fn(ofd.impl)
+	}
+}
+
+// OpenFD represents an open file descriptor on the protocol. It resonates
+// closely with a Linux file descriptor. Its operations are limited to the
+// file. Its operations are not allowed to modify or traverse the filesystem
+// tree. See OpenFDImpl for the supported operations.
+//
+// Reference Model:
+// * An OpenFD takes a reference on the control FD it was opened on.
+type OpenFD struct {
+	openFDRefs
+	openFDEntry
+
+	// All the following fields are immutable.
+
+	// controlFD is the ControlFD on which this FD was opened. OpenFD holds a ref
+	// on controlFD for its entire lifetime.
+	controlFD *ControlFD
+
+	// id is the unique FD identifier which identifies this FD on its connection.
+	id FDID
+
+	// Access mode for this FD, derived from the open flags' O_ACCMODE bits.
+	readable bool
+	writable bool
+
+	// impl is the open FD implementation which embeds this struct. It
+	// contains all the implementation specific details.
+	impl OpenFDImpl
+}
+
+var _ genericFD = (*OpenFD)(nil)
+
+// ID returns fd's ID.
+func (fd *OpenFD) ID() FDID {
+	return fd.id
+}
+
+// ControlFD returns the control FD on which this FD was opened.
+func (fd *OpenFD) ControlFD() ControlFDImpl {
+	return fd.controlFD.impl
+}
+
+// DecRef implements refsvfs2.RefCounter.DecRef. Note that the context
+// parameter should never be used. It exists solely to comply with the
+// refsvfs2.RefCounter interface.
+//
+// On the final ref drop, the FD unlinks itself from its control FD's open-FD
+// list, drops its ref on the control FD, and closes its implementation.
+func (fd *OpenFD) DecRef(context.Context) {
+	fd.openFDRefs.DecRef(func() {
+		fd.controlFD.openFDsMu.Lock()
+		fd.controlFD.openFDs.Remove(fd)
+		fd.controlFD.openFDsMu.Unlock()
+		fd.controlFD.DecRef(nil) // Drop the ref on the control FD.
+		fd.impl.Close(fd.controlFD.conn)
+	})
+}
+
+// Init must be called before first use of fd. It links fd into cfd's open-FD
+// list and registers it with the connection.
+func (fd *OpenFD) Init(cfd *ControlFD, flags uint32, impl OpenFDImpl) {
+	// Initialize fd with 1 ref which is transferred to c via c.insertFD().
+	fd.openFDRefs.InitRefs()
+	fd.controlFD = cfd
+	fd.id = cfd.conn.insertFD(fd)
+	accessMode := flags & unix.O_ACCMODE
+	fd.readable = accessMode == unix.O_RDONLY || accessMode == unix.O_RDWR
+	fd.writable = accessMode == unix.O_WRONLY || accessMode == unix.O_RDWR
+	fd.impl = impl
+	cfd.IncRef() // Holds a ref on cfd for its lifetime.
+	cfd.openFDsMu.Lock()
+	cfd.openFDs.PushBack(fd)
+	cfd.openFDsMu.Unlock()
+}
+
+// ControlFDImpl contains implementation details for a ControlFD.
+// Implementations of ControlFDImpl should contain their associated ControlFD
+// by value as their first field.
+//
+// The operations that perform path traversal or any modification to the
+// filesystem tree must synchronize those modifications with the server's
+// rename mutex.
+//
+// Methods taking a Communicator marshal their response into its payload
+// buffer and return the number of payload bytes written.
+type ControlFDImpl interface {
+	FD() *ControlFD
+	Close(c *Connection)
+	Stat(c *Connection, comm Communicator) (uint32, error)
+	SetStat(c *Connection, comm Communicator, stat SetStatReq) (uint32, error)
+	Walk(c *Connection, comm Communicator, path StringArray) (uint32, error)
+	WalkStat(c *Connection, comm Communicator, path StringArray) (uint32, error)
+	Open(c *Connection, comm Communicator, flags uint32) (uint32, error)
+	OpenCreate(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string, flags uint32) (uint32, error)
+	Mkdir(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string) (uint32, error)
+	Mknod(c *Connection, comm Communicator, mode linux.FileMode, uid UID, gid GID, name string, minor uint32, major uint32) (uint32, error)
+	Symlink(c *Connection, comm Communicator, name string, target string, uid UID, gid GID) (uint32, error)
+	Link(c *Connection, comm Communicator, dir ControlFDImpl, name string) (uint32, error)
+	StatFS(c *Connection, comm Communicator) (uint32, error)
+	Readlink(c *Connection, comm Communicator) (uint32, error)
+	Connect(c *Connection, comm Communicator, sockType uint32) error
+	Unlink(c *Connection, name string, flags uint32) error
+	// RenameLocked returns two callbacks (NOTE(review): presumably a
+	// post-commit update and a cleanup — confirm against the rename handler).
+	RenameLocked(c *Connection, newDir ControlFDImpl, newName string) (func(ControlFDImpl), func(), error)
+	GetXattr(c *Connection, comm Communicator, name string, size uint32) (uint32, error)
+	SetXattr(c *Connection, name string, value string, flags uint32) error
+	ListXattr(c *Connection, comm Communicator, size uint64) (uint32, error)
+	RemoveXattr(c *Connection, comm Communicator, name string) error
+}
+
+// OpenFDImpl contains implementation details for a OpenFD. Implementations of
+// OpenFDImpl should contain their associated OpenFD by value as their first
+// field.
+//
+// Since these operations do not perform any path traversal or any modification
+// to the filesystem tree, there is no need to synchronize with rename
+// operations.
+type OpenFDImpl interface {
+	FD() *OpenFD
+	Close(c *Connection)
+	Stat(c *Connection, comm Communicator) (uint32, error)
+	Sync(c *Connection) error
+	Write(c *Connection, comm Communicator, buf []byte, off uint64) (uint32, error)
+	Read(c *Connection, comm Communicator, off uint64, count uint32) (uint32, error)
+	Allocate(c *Connection, mode, off, length uint64) error
+	Flush(c *Connection) error
+	Getdent64(c *Connection, comm Communicator, count uint32, seek0 bool) (uint32, error)
+}
diff --git a/pkg/lisafs/handlers.go b/pkg/lisafs/handlers.go
new file mode 100644
index 000000000..82807734d
--- /dev/null
+++ b/pkg/lisafs/handlers.go
@@ -0,0 +1,768 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "fmt"
+ "path"
+ "path/filepath"
+ "strings"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+const (
+ allowedOpenFlags = unix.O_ACCMODE | unix.O_TRUNC
+ setStatSupportedMask = unix.STATX_MODE | unix.STATX_UID | unix.STATX_GID | unix.STATX_SIZE | unix.STATX_ATIME | unix.STATX_MTIME
+)
+
+// RPCHandler defines a handler that is invoked when the associated message is
+// received. The handler is responsible for:
+//
+// * Unmarshalling the request from the passed payload and interpreting it.
+// * Marshalling the response into the communicator's payload buffer.
+// * Returning the number of payload bytes written.
+// * Donating any FDs (if needed) to comm, which will in turn donate them to
+//   the client.
+type RPCHandler func(c *Connection, comm Communicator, payloadLen uint32) (uint32, error)
+
+var handlers = [...]RPCHandler{
+ Error: ErrorHandler,
+ Mount: MountHandler,
+ Channel: ChannelHandler,
+ FStat: FStatHandler,
+ SetStat: SetStatHandler,
+ Walk: WalkHandler,
+ WalkStat: WalkStatHandler,
+ OpenAt: OpenAtHandler,
+ OpenCreateAt: OpenCreateAtHandler,
+ Close: CloseHandler,
+ FSync: FSyncHandler,
+ PWrite: PWriteHandler,
+ PRead: PReadHandler,
+ MkdirAt: MkdirAtHandler,
+ MknodAt: MknodAtHandler,
+ SymlinkAt: SymlinkAtHandler,
+ LinkAt: LinkAtHandler,
+ FStatFS: FStatFSHandler,
+ FAllocate: FAllocateHandler,
+ ReadLinkAt: ReadLinkAtHandler,
+ Flush: FlushHandler,
+ Connect: ConnectHandler,
+ UnlinkAt: UnlinkAtHandler,
+ RenameAt: RenameAtHandler,
+ Getdents64: Getdents64Handler,
+ FGetXattr: FGetXattrHandler,
+ FSetXattr: FSetXattrHandler,
+ FListXattr: FListXattrHandler,
+ FRemoveXattr: FRemoveXattrHandler,
+}
+
+// ErrorHandler handles Error message.
+func ErrorHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	// Error is strictly a server-to-client response message; a client
+	// sending one is a protocol violation.
+	return 0, unix.EINVAL
+}
+
+// MountHandler handles the Mount RPC. Note that there can not be concurrent
+// executions of MountHandler on a connection because the connection enforces
+// that Mount is the first message on the connection. Only after the connection
+// has been successfully mounted can other channels be created.
+func MountHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req MountReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	// Clean collapses ".", ".." and duplicate separators before the path is
+	// handed to the server implementation; only absolute paths are accepted.
+	mountPath := path.Clean(string(req.MountPath))
+	if !filepath.IsAbs(mountPath) {
+		log.Warningf("mountPath %q is not absolute", mountPath)
+		return 0, unix.EINVAL
+	}
+
+	// A connection may only be mounted once.
+	if c.mounted {
+		log.Warningf("connection has already been mounted at %q", mountPath)
+		return 0, unix.EBUSY
+	}
+
+	rootFD, rootIno, err := c.ServerImpl().Mount(c, mountPath)
+	if err != nil {
+		return 0, err
+	}
+
+	c.server.addMountPoint(rootFD.FD())
+	c.mounted = true
+	// Tell the client the root inode, the server's message size limit, and
+	// which messages this server implementation supports.
+	resp := MountResp{
+		Root:           rootIno,
+		SupportedMs:    c.ServerImpl().SupportedMessages(),
+		MaxMessageSize: primitive.Uint32(c.ServerImpl().MaxMessageSize()),
+	}
+	respPayloadLen := uint32(resp.SizeBytes())
+	resp.MarshalBytes(comm.PayloadBuf(respPayloadLen))
+	return respPayloadLen, nil
+}
+
+// ChannelHandler handles the Channel RPC.
+func ChannelHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	ch, desc, fdSock, err := c.createChannel(c.ServerImpl().MaxMessageSize())
+	if err != nil {
+		return 0, err
+	}
+
+	// Start servicing the channel in a separate goroutine.
+	c.activeWg.Add(1)
+	go func() {
+		if err := c.service(ch); err != nil {
+			// Don't log shutdown error which is expected during server shutdown.
+			if _, ok := err.(flipcall.ShutdownError); !ok {
+				log.Warningf("lisafs.Connection.service(channel = @%p): %v", ch, err)
+			}
+		}
+		c.activeWg.Done()
+	}()
+
+	// Dup the data FD so the server keeps its own copy; the dup is what gets
+	// donated to the client.
+	clientDataFD, err := unix.Dup(desc.FD)
+	if err != nil {
+		unix.Close(fdSock)
+		ch.shutdown()
+		return 0, err
+	}
+
+	// Respond to client with successful channel creation message.
+	// NOTE(review): if either DonateFD call fails, clientDataFD/fdSock are
+	// not closed here and the servicing goroutine keeps running — presumably
+	// a donation failure tears down the whole connection; confirm DonateFD's
+	// ownership semantics on error.
+	if err := comm.DonateFD(clientDataFD); err != nil {
+		return 0, err
+	}
+	if err := comm.DonateFD(fdSock); err != nil {
+		return 0, err
+	}
+	resp := ChannelResp{
+		dataOffset: desc.Offset,
+		dataLength: uint64(desc.Length),
+	}
+	respLen := uint32(resp.SizeBytes())
+	resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+	return respLen, nil
+}
+
+// FStatHandler handles the FStat RPC.
+func FStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req StatReq
+	req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+	fd, err := c.lookupFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+
+	// StatReq.FD may name either a control FD or an open FD; both kinds
+	// support stat, so dispatch on the concrete type.
+	switch t := fd.(type) {
+	case *ControlFD:
+		return t.impl.Stat(c, comm)
+	case *OpenFD:
+		return t.impl.Stat(c, comm)
+	default:
+		// lookupFD only returns the two types above; anything else is a bug.
+		panic(fmt.Sprintf("unknown fd type %T", t))
+	}
+}
+
+// SetStatHandler handles the SetStat RPC.
+func SetStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	// Attribute changes mutate the filesystem; reject on read-only connections.
+	if c.readonly {
+		return 0, unix.EROFS
+	}
+
+	var req SetStatReq
+	req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+	fd, err := c.LookupControlFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+
+	// Only the statx bits in setStatSupportedMask (mode/uid/gid/size/times)
+	// may be set through this RPC.
+	if req.Mask&^setStatSupportedMask != 0 {
+		return 0, unix.EPERM
+	}
+
+	return fd.impl.SetStat(c, comm, req)
+}
+
+// WalkHandler handles the Walk RPC.
+func WalkHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req WalkReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	fd, err := c.LookupControlFD(req.DirFD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	// Walks must start from a directory.
+	if !fd.IsDir() {
+		return 0, unix.ENOTDIR
+	}
+	// Reject components that could escape the directory: "", ".", "..", or
+	// anything containing '/'.
+	for _, name := range req.Path {
+		if err := checkSafeName(name); err != nil {
+			return 0, err
+		}
+	}
+
+	return fd.impl.Walk(c, comm, req.Path)
+}
+
+// WalkStatHandler handles the WalkStat RPC.
+func WalkStatHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req WalkReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	fd, err := c.LookupControlFD(req.DirFD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+
+	// Note that this fd is allowed to not actually be a directory when the
+	// only path component to walk is "" (self).
+	if !fd.IsDir() {
+		if len(req.Path) > 1 || (len(req.Path) == 1 && len(req.Path[0]) > 0) {
+			return 0, unix.ENOTDIR
+		}
+	}
+	// Validate every component except a leading "", which requests stat
+	// results for the starting directory itself.
+	for i, name := range req.Path {
+		// First component is allowed to be "".
+		if i == 0 && len(name) == 0 {
+			continue
+		}
+		if err := checkSafeName(name); err != nil {
+			return 0, err
+		}
+	}
+
+	return fd.impl.WalkStat(c, comm, req.Path)
+}
+
+// OpenAtHandler handles the OpenAt RPC.
+func OpenAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req OpenAtReq
+	req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+	// Only keep allowed open flags (access mode and O_TRUNC); everything
+	// else the client sent is silently dropped.
+	if allowedFlags := req.Flags & allowedOpenFlags; allowedFlags != req.Flags {
+		log.Debugf("discarding open flags that are not allowed: old open flags = %d, new open flags = %d", req.Flags, allowedFlags)
+		req.Flags = allowedFlags
+	}
+
+	// Read-only connections may only open for reading without truncation.
+	accessMode := req.Flags & unix.O_ACCMODE
+	trunc := req.Flags&unix.O_TRUNC != 0
+	if c.readonly && (accessMode != unix.O_RDONLY || trunc) {
+		return 0, unix.EROFS
+	}
+
+	fd, err := c.LookupControlFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	if fd.IsDir() {
+		// Directory is not truncatable and must be opened with O_RDONLY.
+		if accessMode != unix.O_RDONLY || trunc {
+			return 0, unix.EISDIR
+		}
+	}
+
+	return fd.impl.Open(c, comm, req.Flags)
+}
+
+// OpenCreateAtHandler handles the OpenCreateAt RPC.
+func OpenCreateAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	// File creation mutates the filesystem; reject on read-only connections.
+	if c.readonly {
+		return 0, unix.EROFS
+	}
+	var req OpenCreateAtReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	// Only keep allowed open flags (access mode and O_TRUNC).
+	if allowedFlags := req.Flags & allowedOpenFlags; allowedFlags != req.Flags {
+		log.Debugf("discarding open flags that are not allowed: old open flags = %d, new open flags = %d", req.Flags, allowedFlags)
+		req.Flags = allowedFlags
+	}
+
+	// Reject "", ".", ".." and names containing '/'.
+	name := string(req.Name)
+	if err := checkSafeName(name); err != nil {
+		return 0, err
+	}
+
+	fd, err := c.LookupControlFD(req.DirFD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	// New files can only be created inside directories.
+	if !fd.IsDir() {
+		return 0, unix.ENOTDIR
+	}
+
+	return fd.impl.OpenCreate(c, comm, req.Mode, req.UID, req.GID, name, uint32(req.Flags))
+}
+
+// CloseHandler handles the Close RPC.
+func CloseHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req CloseReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+	// Drop every FD named in the request from the connection's FD table.
+	for i := range req.FDs {
+		c.RemoveFD(req.FDs[i])
+	}
+
+	// Close has no response payload.
+	return 0, nil
+}
+
+// FSyncHandler handles the FSync RPC.
+func FSyncHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req FsyncReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	// Return the first error we encounter, but sync everything we can
+	// regardless.
+	var retErr error
+	for _, fdid := range req.FDs {
+		if err := c.fsyncFD(fdid); err != nil && retErr == nil {
+			retErr = err
+		}
+	}
+
+	// There is no response message for this.
+	return 0, retErr
+}
+
+// fsyncFD looks up the open FD for id, syncs it, and releases the lookup
+// reference before returning.
+func (c *Connection) fsyncFD(id FDID) error {
+	fd, err := c.LookupOpenFD(id)
+	if err != nil {
+		return err
+	}
+	// LookupOpenFD takes a reference on success (every other handler in this
+	// file DecRefs after it); without this, each FSync request leaked a
+	// reference on the FD.
+	defer fd.DecRef(nil)
+	return fd.impl.Sync(c)
+}
+
+// PWriteHandler handles the PWrite RPC.
+func PWriteHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	// Writes mutate the filesystem; reject on read-only connections.
+	if c.readonly {
+		return 0, unix.EROFS
+	}
+	var req PWriteReq
+	// Note that it is an optimized Unmarshal operation which avoids any buffer
+	// allocation and copying. req.Buf just points to payload. This is safe to do
+	// as the handler owns payload and req's lifetime is limited to the handler.
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+	fd, err := c.LookupOpenFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	// Drop the reference taken by LookupOpenFD. This was previously missing
+	// here (every sibling handler DecRefs), leaking a reference per PWrite.
+	defer fd.DecRef(nil)
+	// Writing requires an FD opened for writing.
+	if !fd.writable {
+		return 0, unix.EBADF
+	}
+	return fd.impl.Write(c, comm, req.Buf, uint64(req.Offset))
+}
+
+// PReadHandler handles the PRead RPC.
+func PReadHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req PReadReq
+	req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+	openFD, err := c.LookupOpenFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	defer openFD.DecRef(nil)
+	// Reading requires an FD opened for reading.
+	if !openFD.readable {
+		return 0, unix.EBADF
+	}
+	return openFD.impl.Read(c, comm, req.Offset, req.Count)
+}
+
+// MkdirAtHandler handles the MkdirAt RPC.
+func MkdirAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	// Directory creation mutates the filesystem; reject on read-only connections.
+	if c.readonly {
+		return 0, unix.EROFS
+	}
+	var req MkdirAtReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	// Reject "", ".", ".." and names containing '/'.
+	name := string(req.Name)
+	if err := checkSafeName(name); err != nil {
+		return 0, err
+	}
+
+	fd, err := c.LookupControlFD(req.DirFD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	// New entries can only be created inside directories.
+	if !fd.IsDir() {
+		return 0, unix.ENOTDIR
+	}
+	return fd.impl.Mkdir(c, comm, req.Mode, req.UID, req.GID, name)
+}
+
+// MknodAtHandler handles the MknodAt RPC.
+func MknodAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	// Node creation mutates the filesystem; reject on read-only connections.
+	if c.readonly {
+		return 0, unix.EROFS
+	}
+	var req MknodAtReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	// Reject "", ".", ".." and names containing '/'.
+	name := string(req.Name)
+	if err := checkSafeName(name); err != nil {
+		return 0, err
+	}
+
+	fd, err := c.LookupControlFD(req.DirFD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	// New entries can only be created inside directories.
+	if !fd.IsDir() {
+		return 0, unix.ENOTDIR
+	}
+	return fd.impl.Mknod(c, comm, req.Mode, req.UID, req.GID, name, uint32(req.Minor), uint32(req.Major))
+}
+
+// SymlinkAtHandler handles the SymlinkAt RPC.
+func SymlinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	// Symlink creation mutates the filesystem; reject on read-only connections.
+	if c.readonly {
+		return 0, unix.EROFS
+	}
+	var req SymlinkAtReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	// Only the link name is validated; the target may be any string.
+	name := string(req.Name)
+	if err := checkSafeName(name); err != nil {
+		return 0, err
+	}
+
+	fd, err := c.LookupControlFD(req.DirFD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	// New entries can only be created inside directories.
+	if !fd.IsDir() {
+		return 0, unix.ENOTDIR
+	}
+	return fd.impl.Symlink(c, comm, name, string(req.Target), req.UID, req.GID)
+}
+
+// LinkAtHandler handles the LinkAt RPC.
+func LinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	// Link creation mutates the filesystem; reject on read-only connections.
+	if c.readonly {
+		return 0, unix.EROFS
+	}
+	var req LinkAtReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	// Reject "", ".", ".." and names containing '/'.
+	name := string(req.Name)
+	if err := checkSafeName(name); err != nil {
+		return 0, err
+	}
+
+	fd, err := c.LookupControlFD(req.DirFD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	// New entries can only be created inside directories.
+	if !fd.IsDir() {
+		return 0, unix.ENOTDIR
+	}
+
+	targetFD, err := c.LookupControlFD(req.Target)
+	if err != nil {
+		return 0, err
+	}
+	// Drop the reference taken by LookupControlFD on the link target. This
+	// was previously missing, leaking a target reference on every LinkAt.
+	defer targetFD.DecRef(nil)
+	return targetFD.impl.Link(c, comm, fd.impl, name)
+}
+
+// FStatFSHandler handles the FStatFS RPC.
+func FStatFSHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req FStatFSReq
+	req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+	controlFD, err := c.LookupControlFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	defer controlFD.DecRef(nil)
+	// The implementation writes the statfs results into comm's payload
+	// buffer and returns the payload length.
+	return controlFD.impl.StatFS(c, comm)
+}
+
+// FAllocateHandler handles the FAllocate RPC.
+func FAllocateHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	// Allocation mutates the file; reject on read-only connections.
+	if c.readonly {
+		return 0, unix.EROFS
+	}
+	var req FAllocateReq
+	req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+	fd, err := c.LookupOpenFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	// fallocate requires an FD opened for writing.
+	if !fd.writable {
+		return 0, unix.EBADF
+	}
+	// No response payload; only an error (if any) is returned.
+	return 0, fd.impl.Allocate(c, req.Mode, req.Offset, req.Length)
+}
+
+// ReadLinkAtHandler handles the ReadLinkAt RPC.
+func ReadLinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req ReadLinkAtReq
+	req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+	linkFD, err := c.LookupControlFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	defer linkFD.DecRef(nil)
+	// readlinkat(2) semantics: only symlinks can be read; anything else is
+	// EINVAL.
+	if !linkFD.IsSymlink() {
+		return 0, unix.EINVAL
+	}
+	return linkFD.impl.Readlink(c, comm)
+}
+
+// FlushHandler handles the Flush RPC.
+func FlushHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req FlushReq
+	req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+	openFD, err := c.LookupOpenFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	defer openFD.DecRef(nil)
+
+	// Flush has no response payload; only an error (if any) is returned.
+	return 0, openFD.impl.Flush(c)
+}
+
+// ConnectHandler handles the Connect RPC.
+func ConnectHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req ConnectReq
+	req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+	fd, err := c.LookupControlFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	// Connect only makes sense on socket files.
+	if !fd.IsSocket() {
+		return 0, unix.ENOTSOCK
+	}
+	// No response payload; the connected socket is donated via comm.
+	return 0, fd.impl.Connect(c, comm, req.SockType)
+}
+
+// UnlinkAtHandler handles the UnlinkAt RPC.
+func UnlinkAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	// Unlinking mutates the filesystem; reject on read-only connections.
+	if c.readonly {
+		return 0, unix.EROFS
+	}
+	var req UnlinkAtReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	// Reject "", ".", ".." and names containing '/'.
+	name := string(req.Name)
+	if err := checkSafeName(name); err != nil {
+		return 0, err
+	}
+
+	fd, err := c.LookupControlFD(req.DirFD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	// Entries can only be unlinked from directories.
+	if !fd.IsDir() {
+		return 0, unix.ENOTDIR
+	}
+	// No response payload; only an error (if any) is returned.
+	return 0, fd.impl.Unlink(c, name, uint32(req.Flags))
+}
+
+// RenameAtHandler handles the RenameAt RPC.
+func RenameAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	// Renaming mutates the filesystem; reject on read-only connections.
+	if c.readonly {
+		return 0, unix.EROFS
+	}
+	var req RenameAtReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	// Reject "", ".", ".." and names containing '/'.
+	newName := string(req.NewName)
+	if err := checkSafeName(newName); err != nil {
+		return 0, err
+	}
+
+	renamed, err := c.LookupControlFD(req.Renamed)
+	if err != nil {
+		return 0, err
+	}
+	defer renamed.DecRef(nil)
+
+	newDir, err := c.LookupControlFD(req.NewDir)
+	if err != nil {
+		return 0, err
+	}
+	defer newDir.DecRef(nil)
+	// The destination must be a directory.
+	if !newDir.IsDir() {
+		return 0, unix.ENOTDIR
+	}
+
+	// Hold RenameMu for writing during rename, this is important.
+	c.server.RenameMu.Lock()
+	defer c.server.RenameMu.Unlock()
+
+	if renamed.parent == nil {
+		// renamed is root. The root of a mount cannot be renamed.
+		return 0, unix.EBUSY
+	}
+
+	// Compute the pre-rename path (requires RenameMu, hence *Locked).
+	oldParentPath := renamed.parent.FilePathLocked()
+	oldPath := oldParentPath + "/" + renamed.name
+	if newName == renamed.name && oldParentPath == newDir.FilePathLocked() {
+		// Nothing to do.
+		return 0, nil
+	}
+
+	updateControlFD, cleanUp, err := renamed.impl.RenameLocked(c, newDir.impl, newName)
+	if err != nil {
+		return 0, err
+	}
+
+	// The host rename succeeded; update the server's in-memory FD tree for
+	// every mount point whose subtree contains the old path.
+	c.server.forEachMountPoint(func(root *ControlFD) {
+		if !strings.HasPrefix(oldPath, root.name) {
+			return
+		}
+		pit := fspath.Parse(oldPath[len(root.name):]).Begin
+		root.renameRecursiveLocked(newDir, newName, pit, updateControlFD)
+	})
+
+	// Run the implementation's post-rename cleanup, if any.
+	if cleanUp != nil {
+		cleanUp()
+	}
+	return 0, nil
+}
+
+// renameRecursiveLocked walks fd's subtree along pit and, for every FD whose
+// relative path matches the full iterator, detaches it from its old parent
+// and re-attaches it under newDir with newName, invoking updateControlFD (if
+// non-nil) on each renamed FD's implementation.
+//
+// Precondition: rename mutex must be locked for writing.
+func (fd *ControlFD) renameRecursiveLocked(newDir *ControlFD, newName string, pit fspath.Iterator, updateControlFD func(ControlFDImpl)) {
+	if !pit.Ok() {
+		// fd should be renamed.
+		fd.clearParentLocked()
+		fd.setParentLocked(newDir)
+		fd.name = newName
+		if updateControlFD != nil {
+			updateControlFD(fd.impl)
+		}
+		return
+	}
+
+	cur := pit.String()
+	next := pit.Next()
+	// No need to hold fd.childrenMu because RenameMu is locked for writing.
+	// Multiple children may share a name (separate FDs for the same entry),
+	// so recurse into every match.
+	for child := fd.children.Front(); child != nil; child = child.Next() {
+		if child.name == cur {
+			child.renameRecursiveLocked(newDir, newName, next, updateControlFD)
+		}
+	}
+}
+
+// Getdents64Handler handles the Getdents64 RPC.
+func Getdents64Handler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req Getdents64Req
+	req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+	fd, err := c.LookupOpenFD(req.DirFD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	// Directory entries can only be read from directories.
+	if !fd.controlFD.IsDir() {
+		return 0, unix.ENOTDIR
+	}
+
+	// A negative Count encodes a request to rewind the directory offset to 0
+	// before reading; its magnitude is the actual byte count.
+	seek0 := false
+	if req.Count < 0 {
+		seek0 = true
+		req.Count = -req.Count
+	}
+	return fd.impl.Getdent64(c, comm, uint32(req.Count), seek0)
+}
+
+// FGetXattrHandler handles the FGetXattr RPC.
+func FGetXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req FGetXattrReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	controlFD, err := c.LookupControlFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	defer controlFD.DecRef(nil)
+	// The implementation writes the attribute value (bounded by BufSize)
+	// into comm's payload buffer and returns the payload length.
+	return controlFD.impl.GetXattr(c, comm, string(req.Name), uint32(req.BufSize))
+}
+
+// FSetXattrHandler handles the FSetXattr RPC.
+func FSetXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	// Setting xattrs mutates the file; reject on read-only connections.
+	if c.readonly {
+		return 0, unix.EROFS
+	}
+	var req FSetXattrReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	fd, err := c.LookupControlFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	// No response payload; only an error (if any) is returned.
+	return 0, fd.impl.SetXattr(c, string(req.Name), string(req.Value), uint32(req.Flags))
+}
+
+// FListXattrHandler handles the FListXattr RPC.
+func FListXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	var req FListXattrReq
+	req.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+
+	fd, err := c.LookupControlFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	// The implementation writes the attribute name list (bounded by Size)
+	// into comm's payload buffer and returns the payload length.
+	return fd.impl.ListXattr(c, comm, req.Size)
+}
+
+// FRemoveXattrHandler handles the FRemoveXattr RPC.
+func FRemoveXattrHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) {
+	// Removing xattrs mutates the file; reject on read-only connections.
+	if c.readonly {
+		return 0, unix.EROFS
+	}
+	var req FRemoveXattrReq
+	req.UnmarshalBytes(comm.PayloadBuf(payloadLen))
+
+	fd, err := c.LookupControlFD(req.FD)
+	if err != nil {
+		return 0, err
+	}
+	defer fd.DecRef(nil)
+	// No response payload; only an error (if any) is returned.
+	return 0, fd.impl.RemoveXattr(c, comm, string(req.Name))
+}
+
+// checkSafeName returns nil iff name is a usable path component: non-empty,
+// not "." or "..", and free of path separators. Otherwise returns EINVAL.
+func checkSafeName(name string) error {
+	switch {
+	case name == "", name == ".", name == "..":
+		return unix.EINVAL
+	case strings.Contains(name, "/"):
+		return unix.EINVAL
+	default:
+		return nil
+	}
+}
diff --git a/pkg/lisafs/lisafs.go b/pkg/lisafs/lisafs.go
new file mode 100644
index 000000000..4d8e956ab
--- /dev/null
+++ b/pkg/lisafs/lisafs.go
@@ -0,0 +1,18 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package lisafs (LInux SAndbox FileSystem) defines the protocol for
+// filesystem RPCs between an untrusted Sandbox (client) and a trusted
+// filesystem server.
+package lisafs
diff --git a/pkg/lisafs/message.go b/pkg/lisafs/message.go
new file mode 100644
index 000000000..722afd0be
--- /dev/null
+++ b/pkg/lisafs/message.go
@@ -0,0 +1,1251 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "math"
+ "os"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+// Messages have two parts:
+// * A transport header used to decipher received messages.
+// * A byte array referred to as "payload" which contains the actual message.
+//
+// "dataLen" refers to the size of both combined.
+
+// MID (message ID) is used to identify messages to parse from payload.
+//
+// +marshal slice:MIDSlice
+type MID uint16
+
+// These constants are used to identify their corresponding message types.
+const (
+ // Error is only used in responses to pass errors to client.
+ Error MID = 0
+
+ // Mount is used to establish connection between the client and server mount
+ // point. lisafs requires that the client makes a successful Mount RPC before
+ // making other RPCs.
+ Mount MID = 1
+
+ // Channel requests to start a new communicational channel.
+ Channel MID = 2
+
+ // FStat requests the stat(2) results for a specified file.
+ FStat MID = 3
+
+ // SetStat requests to change file attributes. Note that there is no one
+ // corresponding Linux syscall. This is a conglomeration of fchmod(2),
+ // fchown(2), ftruncate(2) and futimesat(2).
+ SetStat MID = 4
+
+ // Walk requests to walk the specified path starting from the specified
+ // directory. Server-side path traversal is terminated preemptively on
+ // symlinks entries because they can cause non-linear traversal.
+ Walk MID = 5
+
+ // WalkStat is the same as Walk, except the following differences:
+ // * If the first path component is "", then it also returns stat results
+ // for the directory where the walk starts.
+ // * Does not return Inode, just the Stat results for each path component.
+ WalkStat MID = 6
+
+ // OpenAt is analogous to openat(2). It does not perform any walk. It merely
+ // duplicates the control FD with the open flags passed.
+ OpenAt MID = 7
+
+ // OpenCreateAt is analogous to openat(2) with O_CREAT|O_EXCL added to flags.
+ // It also returns the newly created file inode.
+ OpenCreateAt MID = 8
+
+ // Close is analogous to close(2) but can work on multiple FDs.
+ Close MID = 9
+
+ // FSync is analogous to fsync(2) but can work on multiple FDs.
+ FSync MID = 10
+
+ // PWrite is analogous to pwrite(2).
+ PWrite MID = 11
+
+ // PRead is analogous to pread(2).
+ PRead MID = 12
+
+ // MkdirAt is analogous to mkdirat(2).
+ MkdirAt MID = 13
+
+ // MknodAt is analogous to mknodat(2).
+ MknodAt MID = 14
+
+ // SymlinkAt is analogous to symlinkat(2).
+ SymlinkAt MID = 15
+
+ // LinkAt is analogous to linkat(2).
+ LinkAt MID = 16
+
+ // FStatFS is analogous to fstatfs(2).
+ FStatFS MID = 17
+
+ // FAllocate is analogous to fallocate(2).
+ FAllocate MID = 18
+
+ // ReadLinkAt is analogous to readlinkat(2).
+ ReadLinkAt MID = 19
+
+ // Flush cleans up the file state. Its behavior is implementation
+ // dependent and might not even be supported in server implementations.
+ Flush MID = 20
+
+ // Connect is loosely analogous to connect(2).
+ Connect MID = 21
+
+ // UnlinkAt is analogous to unlinkat(2).
+ UnlinkAt MID = 22
+
+ // RenameAt is loosely analogous to renameat(2).
+ RenameAt MID = 23
+
+ // Getdents64 is analogous to getdents64(2).
+ Getdents64 MID = 24
+
+ // FGetXattr is analogous to fgetxattr(2).
+ FGetXattr MID = 25
+
+ // FSetXattr is analogous to fsetxattr(2).
+ FSetXattr MID = 26
+
+ // FListXattr is analogous to flistxattr(2).
+ FListXattr MID = 27
+
+ // FRemoveXattr is analogous to fremovexattr(2).
+ FRemoveXattr MID = 28
+)
+
+const (
+ // NoUID is a sentinel used to indicate no valid UID.
+ NoUID UID = math.MaxUint32
+
+ // NoGID is a sentinel used to indicate no valid GID.
+ NoGID GID = math.MaxUint32
+)
+
+// MaxMessageSize is the recommended max message size that can be used by
+// connections. Server implementations may choose to use other values.
+func MaxMessageSize() uint32 {
+	// Size the limit so that MaxMessageSize() plus the flipcall header and
+	// channel header overhead lands on HugePageSize. That way the flipcall
+	// packet window allocates exactly HugePageSize and can be backed by a
+	// single huge page if the underlying memfd supports it.
+	pageSize := os.Getpagesize()
+	return uint32(hostarch.HugePageSize - pageSize)
+}
+
+// TODO(gvisor.dev/issue/6450): Once this is resolved:
+// * Update manual implementations and function signatures.
+// * Update RPC handlers and appropriate callers to handle errors correctly.
+// * Update manual implementations to get rid of buffer shifting.
+
+// UID represents a user ID.
+//
+// +marshal
+type UID uint32
+
+// Ok returns true if uid is not NoUID.
+func (uid UID) Ok() bool {
+ return uid != NoUID
+}
+
+// GID represents a group ID.
+//
+// +marshal
+type GID uint32
+
+// Ok returns true if gid is not NoGID.
+func (gid GID) Ok() bool {
+ return gid != NoGID
+}
+
+// NoopMarshal is a noop implementation of marshal.Marshallable.MarshalBytes.
+func NoopMarshal([]byte) {}
+
+// NoopUnmarshal is a noop implementation of marshal.Marshallable.UnmarshalBytes.
+func NoopUnmarshal([]byte) {}
+
+// SizedString represents a string in memory. The marshalled string bytes are
+// preceded by a uint32 signifying the string length.
+type SizedString string
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *SizedString) SizeBytes() int {
+	// 4-byte length prefix followed by the raw string bytes.
+	return (*primitive.Uint32)(nil).SizeBytes() + len(*s)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *SizedString) MarshalBytes(dst []byte) {
+	strLen := primitive.Uint32(len(*s))
+	strLen.MarshalUnsafe(dst)
+	dst = dst[strLen.SizeBytes():]
+	// Copy without any allocation.
+	copy(dst[:strLen], *s)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *SizedString) UnmarshalBytes(src []byte) {
+	var strLen primitive.Uint32
+	strLen.UnmarshalUnsafe(src)
+	src = src[strLen.SizeBytes():]
+	// Take the hit, this leads to an allocation + memcpy. No way around it.
+	*s = SizedString(src[:strLen])
+}
+
+// StringArray represents an array of SizedStrings in memory. The marshalled
+// array data is preceded by a uint32 signifying the array length.
+type StringArray []string
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *StringArray) SizeBytes() int {
+	// 4-byte element count followed by each element's length-prefixed bytes.
+	size := (*primitive.Uint32)(nil).SizeBytes()
+	for _, str := range *s {
+		sstr := SizedString(str)
+		size += sstr.SizeBytes()
+	}
+	return size
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *StringArray) MarshalBytes(dst []byte) {
+	arrLen := primitive.Uint32(len(*s))
+	arrLen.MarshalUnsafe(dst)
+	dst = dst[arrLen.SizeBytes():]
+	for _, str := range *s {
+		sstr := SizedString(str)
+		sstr.MarshalBytes(dst)
+		dst = dst[sstr.SizeBytes():]
+	}
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *StringArray) UnmarshalBytes(src []byte) {
+	var arrLen primitive.Uint32
+	arrLen.UnmarshalUnsafe(src)
+	src = src[arrLen.SizeBytes():]
+
+	// Reuse the existing backing array when it is large enough to avoid
+	// reallocating on every unmarshal.
+	if cap(*s) < int(arrLen) {
+		*s = make([]string, arrLen)
+	} else {
+		*s = (*s)[:arrLen]
+	}
+
+	for i := primitive.Uint32(0); i < arrLen; i++ {
+		var sstr SizedString
+		sstr.UnmarshalBytes(src)
+		src = src[sstr.SizeBytes():]
+		(*s)[i] = string(sstr)
+	}
+}
+
+// Inode represents an inode on the remote filesystem.
+//
+// +marshal slice:InodeSlice
+type Inode struct {
+ ControlFD FDID
+ _ uint32 // Need to make struct packed.
+ Stat linux.Statx
+}
+
+// MountReq represents a Mount request.
+type MountReq struct {
+	// MountPath is the absolute path the client wants to mount; validated
+	// and cleaned server-side by MountHandler.
+	MountPath SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MountReq) SizeBytes() int {
+	return m.MountPath.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MountReq) MarshalBytes(dst []byte) {
+	m.MountPath.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MountReq) UnmarshalBytes(src []byte) {
+	m.MountPath.UnmarshalBytes(src)
+}
+
+// MountResp represents a Mount response.
+//
+// Wire layout: Root inode, MaxMessageSize, a uint16 count of supported
+// messages, then that many MIDs.
+type MountResp struct {
+	Root Inode
+	// MaxMessageSize is the maximum size of messages communicated between the
+	// client and server in bytes. This includes the communication header.
+	MaxMessageSize primitive.Uint32
+	// SupportedMs holds all the supported messages.
+	SupportedMs []MID
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MountResp) SizeBytes() int {
+	return m.Root.SizeBytes() +
+		m.MaxMessageSize.SizeBytes() +
+		(*primitive.Uint16)(nil).SizeBytes() +
+		(len(m.SupportedMs) * (*MID)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MountResp) MarshalBytes(dst []byte) {
+	m.Root.MarshalUnsafe(dst)
+	dst = dst[m.Root.SizeBytes():]
+	m.MaxMessageSize.MarshalUnsafe(dst)
+	dst = dst[m.MaxMessageSize.SizeBytes():]
+	numSupported := primitive.Uint16(len(m.SupportedMs))
+	numSupported.MarshalBytes(dst)
+	dst = dst[numSupported.SizeBytes():]
+	MarshalUnsafeMIDSlice(m.SupportedMs, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MountResp) UnmarshalBytes(src []byte) {
+	m.Root.UnmarshalUnsafe(src)
+	src = src[m.Root.SizeBytes():]
+	m.MaxMessageSize.UnmarshalUnsafe(src)
+	src = src[m.MaxMessageSize.SizeBytes():]
+	var numSupported primitive.Uint16
+	numSupported.UnmarshalBytes(src)
+	src = src[numSupported.SizeBytes():]
+	m.SupportedMs = make([]MID, numSupported)
+	UnmarshalUnsafeMIDSlice(m.SupportedMs, src)
+}
+
+// ChannelResp is the response to the create channel request.
+//
+// +marshal
+type ChannelResp struct {
+ dataOffset int64
+ dataLength uint64
+}
+
+// ErrorResp is returned to represent an error while handling a request.
+//
+// +marshal
+type ErrorResp struct {
+ errno uint32
+}
+
+// StatReq requests the stat results for the specified FD.
+//
+// +marshal
+type StatReq struct {
+ FD FDID
+}
+
+// SetStatReq is used to set attributes on FDs.
+//
+// +marshal
+type SetStatReq struct {
+ FD FDID
+ _ uint32
+ Mask uint32
+ Mode uint32 // Only permissions part is settable.
+ UID UID
+ GID GID
+ Size uint64
+ Atime linux.Timespec
+ Mtime linux.Timespec
+}
+
+// SetStatResp is used to communicate SetStat results. It contains a mask
+// representing the failed changes. It also contains the errno of the failed
+// set attribute operation. If multiple operations failed then any of those
+// errnos can be returned.
+//
+// +marshal
+type SetStatResp struct {
+ FailureMask uint32
+ FailureErrNo uint32
+}
+
+// WalkReq is used to request to walk multiple path components at once. This
+// is used for both Walk and WalkStat.
+type WalkReq struct {
+ DirFD FDID
+ Path StringArray
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (w *WalkReq) SizeBytes() int {
+ return w.DirFD.SizeBytes() + w.Path.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (w *WalkReq) MarshalBytes(dst []byte) {
+ w.DirFD.MarshalUnsafe(dst)
+ dst = dst[w.DirFD.SizeBytes():]
+ w.Path.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (w *WalkReq) UnmarshalBytes(src []byte) {
+ w.DirFD.UnmarshalUnsafe(src)
+ src = src[w.DirFD.SizeBytes():]
+ w.Path.UnmarshalBytes(src)
+}
+
+// WalkStatus is used to indicate the reason for partial/unsuccessful server
+// side Walk operations. Please note that partial/unsuccessful walk operations
+// do not necessarily fail the RPC. The RPC is successful with a failure hint
+// which can be used by the client to infer server-side state.
+type WalkStatus = primitive.Uint8
+
+const (
+ // WalkSuccess indicates that all path components were successfully walked.
+ WalkSuccess WalkStatus = iota
+
+ // WalkComponentDoesNotExist indicates that the walk was prematurely
+ // terminated because an intermediate path component does not exist on
+// server. The results of all previously existing path components are returned.
+ WalkComponentDoesNotExist
+
+ // WalkComponentSymlink indicates that the walk was prematurely
+ // terminated because an intermediate path component was a symlink. It is not
+ // safe to resolve symlinks remotely (unaware of mount points).
+ WalkComponentSymlink
+)
+
+// WalkResp is used to communicate the inodes walked by the server.
+type WalkResp struct {
+ Status WalkStatus
+ Inodes []Inode
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (w *WalkResp) SizeBytes() int {
+ return w.Status.SizeBytes() +
+ (*primitive.Uint32)(nil).SizeBytes() + (len(w.Inodes) * (*Inode)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (w *WalkResp) MarshalBytes(dst []byte) {
+ w.Status.MarshalUnsafe(dst)
+ dst = dst[w.Status.SizeBytes():]
+
+ numInodes := primitive.Uint32(len(w.Inodes))
+ numInodes.MarshalUnsafe(dst)
+ dst = dst[numInodes.SizeBytes():]
+
+ MarshalUnsafeInodeSlice(w.Inodes, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (w *WalkResp) UnmarshalBytes(src []byte) {
+ w.Status.UnmarshalUnsafe(src)
+ src = src[w.Status.SizeBytes():]
+
+ var numInodes primitive.Uint32
+ numInodes.UnmarshalUnsafe(src)
+ src = src[numInodes.SizeBytes():]
+
+ if cap(w.Inodes) < int(numInodes) {
+ w.Inodes = make([]Inode, numInodes)
+ } else {
+ w.Inodes = w.Inodes[:numInodes]
+ }
+ UnmarshalUnsafeInodeSlice(w.Inodes, src)
+}
+
+// WalkStatResp is used to communicate stat results for WalkStat.
+type WalkStatResp struct {
+ Stats []linux.Statx
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (w *WalkStatResp) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + (len(w.Stats) * linux.SizeOfStatx)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (w *WalkStatResp) MarshalBytes(dst []byte) {
+ numStats := primitive.Uint32(len(w.Stats))
+ numStats.MarshalUnsafe(dst)
+ dst = dst[numStats.SizeBytes():]
+
+ linux.MarshalUnsafeStatxSlice(w.Stats, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (w *WalkStatResp) UnmarshalBytes(src []byte) {
+ var numStats primitive.Uint32
+ numStats.UnmarshalUnsafe(src)
+ src = src[numStats.SizeBytes():]
+
+ if cap(w.Stats) < int(numStats) {
+ w.Stats = make([]linux.Statx, numStats)
+ } else {
+ w.Stats = w.Stats[:numStats]
+ }
+ linux.UnmarshalUnsafeStatxSlice(w.Stats, src)
+}
+
+// OpenAtReq is used to open existing FDs with the specified flags.
+//
+// +marshal
+type OpenAtReq struct {
+ FD FDID
+ Flags uint32
+}
+
+// OpenAtResp is used to communicate the newly created FD.
+//
+// +marshal
+type OpenAtResp struct {
+ NewFD FDID
+}
+
+// +marshal
+type createCommon struct {
+ DirFD FDID
+ Mode linux.FileMode
+ _ uint16 // Need to make struct packed.
+ UID UID
+ GID GID
+}
+
+// OpenCreateAtReq is used to make OpenCreateAt requests.
+type OpenCreateAtReq struct {
+ createCommon
+ Name SizedString
+ Flags primitive.Uint32
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (o *OpenCreateAtReq) SizeBytes() int {
+ return o.createCommon.SizeBytes() + o.Name.SizeBytes() + o.Flags.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (o *OpenCreateAtReq) MarshalBytes(dst []byte) {
+ o.createCommon.MarshalUnsafe(dst)
+ dst = dst[o.createCommon.SizeBytes():]
+ o.Name.MarshalBytes(dst)
+ dst = dst[o.Name.SizeBytes():]
+ o.Flags.MarshalUnsafe(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (o *OpenCreateAtReq) UnmarshalBytes(src []byte) {
+ o.createCommon.UnmarshalUnsafe(src)
+ src = src[o.createCommon.SizeBytes():]
+ o.Name.UnmarshalBytes(src)
+ src = src[o.Name.SizeBytes():]
+ o.Flags.UnmarshalUnsafe(src)
+}
+
+// OpenCreateAtResp is used to communicate successful OpenCreateAt results.
+//
+// +marshal
+type OpenCreateAtResp struct {
+ Child Inode
+ NewFD FDID
+ _ uint32 // Need to make struct packed.
+}
+
+// FdArray is a utility struct which implements a marshallable type for
+// communicating an array of FDIDs. In memory, the array data is preceded by a
+// uint32 denoting the array length.
+type FdArray []FDID
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (f *FdArray) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + (len(*f) * (*FDID)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (f *FdArray) MarshalBytes(dst []byte) {
+ arrLen := primitive.Uint32(len(*f))
+ arrLen.MarshalUnsafe(dst)
+ dst = dst[arrLen.SizeBytes():]
+ MarshalUnsafeFDIDSlice(*f, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (f *FdArray) UnmarshalBytes(src []byte) {
+ var arrLen primitive.Uint32
+ arrLen.UnmarshalUnsafe(src)
+ src = src[arrLen.SizeBytes():]
+ if cap(*f) < int(arrLen) {
+ *f = make(FdArray, arrLen)
+ } else {
+ *f = (*f)[:arrLen]
+ }
+ UnmarshalUnsafeFDIDSlice(*f, src)
+}
+
+// CloseReq is used to close(2) FDs.
+type CloseReq struct {
+ FDs FdArray
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (c *CloseReq) SizeBytes() int {
+ return c.FDs.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (c *CloseReq) MarshalBytes(dst []byte) {
+ c.FDs.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (c *CloseReq) UnmarshalBytes(src []byte) {
+ c.FDs.UnmarshalBytes(src)
+}
+
+// FsyncReq is used to fsync(2) FDs.
+type FsyncReq struct {
+ FDs FdArray
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (f *FsyncReq) SizeBytes() int {
+ return f.FDs.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (f *FsyncReq) MarshalBytes(dst []byte) {
+ f.FDs.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (f *FsyncReq) UnmarshalBytes(src []byte) {
+ f.FDs.UnmarshalBytes(src)
+}
+
+// PReadReq is used to pread(2) on an FD.
+//
+// +marshal
+type PReadReq struct {
+ Offset uint64
+ FD FDID
+ Count uint32
+}
+
+// PReadResp is used to return the result of pread(2).
+type PReadResp struct {
+ NumBytes primitive.Uint32
+ Buf []byte
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (r *PReadResp) SizeBytes() int {
+ return r.NumBytes.SizeBytes() + int(r.NumBytes)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (r *PReadResp) MarshalBytes(dst []byte) {
+ r.NumBytes.MarshalUnsafe(dst)
+ dst = dst[r.NumBytes.SizeBytes():]
+ copy(dst[:r.NumBytes], r.Buf[:r.NumBytes])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (r *PReadResp) UnmarshalBytes(src []byte) {
+ r.NumBytes.UnmarshalUnsafe(src)
+ src = src[r.NumBytes.SizeBytes():]
+
+ // We expect the client to have already allocated r.Buf. r.Buf probably
+ // (optimally) points to usermem. Directly copy into that.
+ copy(r.Buf[:r.NumBytes], src[:r.NumBytes])
+}
+
+// PWriteReq is used to pwrite(2) on an FD.
+type PWriteReq struct {
+ Offset primitive.Uint64
+ FD FDID
+ NumBytes primitive.Uint32
+ Buf []byte
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (w *PWriteReq) SizeBytes() int {
+ return w.Offset.SizeBytes() + w.FD.SizeBytes() + w.NumBytes.SizeBytes() + int(w.NumBytes)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (w *PWriteReq) MarshalBytes(dst []byte) {
+ w.Offset.MarshalUnsafe(dst)
+ dst = dst[w.Offset.SizeBytes():]
+ w.FD.MarshalUnsafe(dst)
+ dst = dst[w.FD.SizeBytes():]
+ w.NumBytes.MarshalUnsafe(dst)
+ dst = dst[w.NumBytes.SizeBytes():]
+ copy(dst[:w.NumBytes], w.Buf[:w.NumBytes])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (w *PWriteReq) UnmarshalBytes(src []byte) {
+ w.Offset.UnmarshalUnsafe(src)
+ src = src[w.Offset.SizeBytes():]
+ w.FD.UnmarshalUnsafe(src)
+ src = src[w.FD.SizeBytes():]
+ w.NumBytes.UnmarshalUnsafe(src)
+ src = src[w.NumBytes.SizeBytes():]
+
+ // This is an optimization. Assuming that the server is making this call, it
+ // is safe to just point to src rather than allocating and copying.
+ w.Buf = src[:w.NumBytes]
+}
+
+// PWriteResp is used to return the result of pwrite(2).
+//
+// +marshal
+type PWriteResp struct {
+ Count uint64
+}
+
+// MkdirAtReq is used to make MkdirAt requests.
+type MkdirAtReq struct {
+ createCommon
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MkdirAtReq) SizeBytes() int {
+ return m.createCommon.SizeBytes() + m.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MkdirAtReq) MarshalBytes(dst []byte) {
+ m.createCommon.MarshalUnsafe(dst)
+ dst = dst[m.createCommon.SizeBytes():]
+ m.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MkdirAtReq) UnmarshalBytes(src []byte) {
+ m.createCommon.UnmarshalUnsafe(src)
+ src = src[m.createCommon.SizeBytes():]
+ m.Name.UnmarshalBytes(src)
+}
+
+// MkdirAtResp is the response to a successful MkdirAt request.
+//
+// +marshal
+type MkdirAtResp struct {
+ ChildDir Inode
+}
+
+// MknodAtReq is used to make MknodAt requests.
+type MknodAtReq struct {
+ createCommon
+ Name SizedString
+ Minor primitive.Uint32
+ Major primitive.Uint32
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MknodAtReq) SizeBytes() int {
+ return m.createCommon.SizeBytes() + m.Name.SizeBytes() + m.Minor.SizeBytes() + m.Major.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MknodAtReq) MarshalBytes(dst []byte) {
+ m.createCommon.MarshalUnsafe(dst)
+ dst = dst[m.createCommon.SizeBytes():]
+ m.Name.MarshalBytes(dst)
+ dst = dst[m.Name.SizeBytes():]
+ m.Minor.MarshalUnsafe(dst)
+ dst = dst[m.Minor.SizeBytes():]
+ m.Major.MarshalUnsafe(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MknodAtReq) UnmarshalBytes(src []byte) {
+ m.createCommon.UnmarshalUnsafe(src)
+ src = src[m.createCommon.SizeBytes():]
+ m.Name.UnmarshalBytes(src)
+ src = src[m.Name.SizeBytes():]
+ m.Minor.UnmarshalUnsafe(src)
+ src = src[m.Minor.SizeBytes():]
+ m.Major.UnmarshalUnsafe(src)
+}
+
+// MknodAtResp is the response to a successful MknodAt request.
+//
+// +marshal
+type MknodAtResp struct {
+ Child Inode
+}
+
+// SymlinkAtReq is used to make SymlinkAt requests.
+type SymlinkAtReq struct {
+ DirFD FDID
+ Name SizedString
+ Target SizedString
+ UID UID
+ GID GID
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *SymlinkAtReq) SizeBytes() int {
+ return s.DirFD.SizeBytes() + s.Name.SizeBytes() + s.Target.SizeBytes() + s.UID.SizeBytes() + s.GID.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *SymlinkAtReq) MarshalBytes(dst []byte) {
+ s.DirFD.MarshalUnsafe(dst)
+ dst = dst[s.DirFD.SizeBytes():]
+ s.Name.MarshalBytes(dst)
+ dst = dst[s.Name.SizeBytes():]
+ s.Target.MarshalBytes(dst)
+ dst = dst[s.Target.SizeBytes():]
+ s.UID.MarshalUnsafe(dst)
+ dst = dst[s.UID.SizeBytes():]
+ s.GID.MarshalUnsafe(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *SymlinkAtReq) UnmarshalBytes(src []byte) {
+ s.DirFD.UnmarshalUnsafe(src)
+ src = src[s.DirFD.SizeBytes():]
+ s.Name.UnmarshalBytes(src)
+ src = src[s.Name.SizeBytes():]
+ s.Target.UnmarshalBytes(src)
+ src = src[s.Target.SizeBytes():]
+ s.UID.UnmarshalUnsafe(src)
+ src = src[s.UID.SizeBytes():]
+ s.GID.UnmarshalUnsafe(src)
+}
+
+// SymlinkAtResp is the response to a successful SymlinkAt request.
+//
+// +marshal
+type SymlinkAtResp struct {
+ Symlink Inode
+}
+
+// LinkAtReq is used to make LinkAt requests.
+type LinkAtReq struct {
+ DirFD FDID
+ Target FDID
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (l *LinkAtReq) SizeBytes() int {
+ return l.DirFD.SizeBytes() + l.Target.SizeBytes() + l.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (l *LinkAtReq) MarshalBytes(dst []byte) {
+ l.DirFD.MarshalUnsafe(dst)
+ dst = dst[l.DirFD.SizeBytes():]
+ l.Target.MarshalUnsafe(dst)
+ dst = dst[l.Target.SizeBytes():]
+ l.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (l *LinkAtReq) UnmarshalBytes(src []byte) {
+ l.DirFD.UnmarshalUnsafe(src)
+ src = src[l.DirFD.SizeBytes():]
+ l.Target.UnmarshalUnsafe(src)
+ src = src[l.Target.SizeBytes():]
+ l.Name.UnmarshalBytes(src)
+}
+
+// LinkAtResp is used to respond to a successful LinkAt request.
+//
+// +marshal
+type LinkAtResp struct {
+ Link Inode
+}
+
+// FStatFSReq is used to request StatFS results for the specified FD.
+//
+// +marshal
+type FStatFSReq struct {
+ FD FDID
+}
+
+// StatFS is the response to a successful FStatFS request.
+//
+// +marshal
+type StatFS struct {
+ Type uint64
+ BlockSize int64
+ Blocks uint64
+ BlocksFree uint64
+ BlocksAvailable uint64
+ Files uint64
+ FilesFree uint64
+ NameLength uint64
+}
+
+// FAllocateReq is used to request to fallocate(2) an FD. This has no response.
+//
+// +marshal
+type FAllocateReq struct {
+ FD FDID
+ _ uint32
+ Mode uint64
+ Offset uint64
+ Length uint64
+}
+
+// ReadLinkAtReq is used to readlinkat(2) at the specified FD.
+//
+// +marshal
+type ReadLinkAtReq struct {
+ FD FDID
+}
+
+// ReadLinkAtResp is used to communicate ReadLinkAt results.
+type ReadLinkAtResp struct {
+ Target SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (r *ReadLinkAtResp) SizeBytes() int {
+ return r.Target.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (r *ReadLinkAtResp) MarshalBytes(dst []byte) {
+ r.Target.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (r *ReadLinkAtResp) UnmarshalBytes(src []byte) {
+ r.Target.UnmarshalBytes(src)
+}
+
+// FlushReq is used to make Flush requests.
+//
+// +marshal
+type FlushReq struct {
+ FD FDID
+}
+
+// ConnectReq is used to make a Connect request.
+//
+// +marshal
+type ConnectReq struct {
+ FD FDID
+ // SockType is used to specify the socket type to connect to. As a special
+ // case, SockType = 0 means that the socket type does not matter and the
+ // requester will accept any socket type.
+ SockType uint32
+}
+
+// UnlinkAtReq is used to make UnlinkAt requests.
+type UnlinkAtReq struct {
+ DirFD FDID
+ Name SizedString
+ Flags primitive.Uint32
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (u *UnlinkAtReq) SizeBytes() int {
+ return u.DirFD.SizeBytes() + u.Name.SizeBytes() + u.Flags.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (u *UnlinkAtReq) MarshalBytes(dst []byte) {
+ u.DirFD.MarshalUnsafe(dst)
+ dst = dst[u.DirFD.SizeBytes():]
+ u.Name.MarshalBytes(dst)
+ dst = dst[u.Name.SizeBytes():]
+ u.Flags.MarshalUnsafe(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (u *UnlinkAtReq) UnmarshalBytes(src []byte) {
+ u.DirFD.UnmarshalUnsafe(src)
+ src = src[u.DirFD.SizeBytes():]
+ u.Name.UnmarshalBytes(src)
+ src = src[u.Name.SizeBytes():]
+ u.Flags.UnmarshalUnsafe(src)
+}
+
+// RenameAtReq is used to make Rename requests. Note that the request takes in
+// the to-be-renamed file's FD instead of oldDir and oldName like renameat(2).
+type RenameAtReq struct {
+ Renamed FDID
+ NewDir FDID
+ NewName SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (r *RenameAtReq) SizeBytes() int {
+ return r.Renamed.SizeBytes() + r.NewDir.SizeBytes() + r.NewName.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (r *RenameAtReq) MarshalBytes(dst []byte) {
+ r.Renamed.MarshalUnsafe(dst)
+ dst = dst[r.Renamed.SizeBytes():]
+ r.NewDir.MarshalUnsafe(dst)
+ dst = dst[r.NewDir.SizeBytes():]
+ r.NewName.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (r *RenameAtReq) UnmarshalBytes(src []byte) {
+ r.Renamed.UnmarshalUnsafe(src)
+ src = src[r.Renamed.SizeBytes():]
+ r.NewDir.UnmarshalUnsafe(src)
+ src = src[r.NewDir.SizeBytes():]
+ r.NewName.UnmarshalBytes(src)
+}
+
+// Getdents64Req is used to make Getdents64 requests.
+//
+// +marshal
+type Getdents64Req struct {
+ DirFD FDID
+ // Count is the number of bytes to read. A negative value of Count is used to
+ // indicate that the implementation must lseek(0, SEEK_SET) before calling
+ // getdents64(2). Implementations must use the absolute value of Count to
+ // determine the number of bytes to read.
+ Count int32
+}
+
+// Dirent64 is analogous to struct linux_dirent64.
+type Dirent64 struct {
+ Ino primitive.Uint64
+ DevMinor primitive.Uint32
+ DevMajor primitive.Uint32
+ Off primitive.Uint64
+ Type primitive.Uint8
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (d *Dirent64) SizeBytes() int {
+ return d.Ino.SizeBytes() + d.DevMinor.SizeBytes() + d.DevMajor.SizeBytes() + d.Off.SizeBytes() + d.Type.SizeBytes() + d.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (d *Dirent64) MarshalBytes(dst []byte) {
+ d.Ino.MarshalUnsafe(dst)
+ dst = dst[d.Ino.SizeBytes():]
+ d.DevMinor.MarshalUnsafe(dst)
+ dst = dst[d.DevMinor.SizeBytes():]
+ d.DevMajor.MarshalUnsafe(dst)
+ dst = dst[d.DevMajor.SizeBytes():]
+ d.Off.MarshalUnsafe(dst)
+ dst = dst[d.Off.SizeBytes():]
+ d.Type.MarshalUnsafe(dst)
+ dst = dst[d.Type.SizeBytes():]
+ d.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (d *Dirent64) UnmarshalBytes(src []byte) {
+ d.Ino.UnmarshalUnsafe(src)
+ src = src[d.Ino.SizeBytes():]
+ d.DevMinor.UnmarshalUnsafe(src)
+ src = src[d.DevMinor.SizeBytes():]
+ d.DevMajor.UnmarshalUnsafe(src)
+ src = src[d.DevMajor.SizeBytes():]
+ d.Off.UnmarshalUnsafe(src)
+ src = src[d.Off.SizeBytes():]
+ d.Type.UnmarshalUnsafe(src)
+ src = src[d.Type.SizeBytes():]
+ d.Name.UnmarshalBytes(src)
+}
+
+// Getdents64Resp is used to communicate getdents64 results.
+type Getdents64Resp struct {
+ Dirents []Dirent64
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (g *Getdents64Resp) SizeBytes() int {
+ ret := (*primitive.Uint32)(nil).SizeBytes()
+ for i := range g.Dirents {
+ ret += g.Dirents[i].SizeBytes()
+ }
+ return ret
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (g *Getdents64Resp) MarshalBytes(dst []byte) {
+ numDirents := primitive.Uint32(len(g.Dirents))
+ numDirents.MarshalUnsafe(dst)
+ dst = dst[numDirents.SizeBytes():]
+ for i := range g.Dirents {
+ g.Dirents[i].MarshalBytes(dst)
+ dst = dst[g.Dirents[i].SizeBytes():]
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (g *Getdents64Resp) UnmarshalBytes(src []byte) {
+ var numDirents primitive.Uint32
+ numDirents.UnmarshalUnsafe(src)
+ if cap(g.Dirents) < int(numDirents) {
+ g.Dirents = make([]Dirent64, numDirents)
+ } else {
+ g.Dirents = g.Dirents[:numDirents]
+ }
+
+ src = src[numDirents.SizeBytes():]
+ for i := range g.Dirents {
+ g.Dirents[i].UnmarshalBytes(src)
+ src = src[g.Dirents[i].SizeBytes():]
+ }
+}
+
+// FGetXattrReq is used to make FGetXattr requests. The response to this is
+// just a SizedString containing the xattr value.
+type FGetXattrReq struct {
+ FD FDID
+ BufSize primitive.Uint32
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (g *FGetXattrReq) SizeBytes() int {
+ return g.FD.SizeBytes() + g.BufSize.SizeBytes() + g.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (g *FGetXattrReq) MarshalBytes(dst []byte) {
+ g.FD.MarshalUnsafe(dst)
+ dst = dst[g.FD.SizeBytes():]
+ g.BufSize.MarshalUnsafe(dst)
+ dst = dst[g.BufSize.SizeBytes():]
+ g.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (g *FGetXattrReq) UnmarshalBytes(src []byte) {
+ g.FD.UnmarshalUnsafe(src)
+ src = src[g.FD.SizeBytes():]
+ g.BufSize.UnmarshalUnsafe(src)
+ src = src[g.BufSize.SizeBytes():]
+ g.Name.UnmarshalBytes(src)
+}
+
+// FGetXattrResp is used to respond to FGetXattr request.
+type FGetXattrResp struct {
+ Value SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (g *FGetXattrResp) SizeBytes() int {
+ return g.Value.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (g *FGetXattrResp) MarshalBytes(dst []byte) {
+ g.Value.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (g *FGetXattrResp) UnmarshalBytes(src []byte) {
+ g.Value.UnmarshalBytes(src)
+}
+
+// FSetXattrReq is used to make FSetXattr requests. It has no response.
+type FSetXattrReq struct {
+ FD FDID
+ Flags primitive.Uint32
+ Name SizedString
+ Value SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (s *FSetXattrReq) SizeBytes() int {
+ return s.FD.SizeBytes() + s.Flags.SizeBytes() + s.Name.SizeBytes() + s.Value.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (s *FSetXattrReq) MarshalBytes(dst []byte) {
+ s.FD.MarshalUnsafe(dst)
+ dst = dst[s.FD.SizeBytes():]
+ s.Flags.MarshalUnsafe(dst)
+ dst = dst[s.Flags.SizeBytes():]
+ s.Name.MarshalBytes(dst)
+ dst = dst[s.Name.SizeBytes():]
+ s.Value.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (s *FSetXattrReq) UnmarshalBytes(src []byte) {
+ s.FD.UnmarshalUnsafe(src)
+ src = src[s.FD.SizeBytes():]
+ s.Flags.UnmarshalUnsafe(src)
+ src = src[s.Flags.SizeBytes():]
+ s.Name.UnmarshalBytes(src)
+ src = src[s.Name.SizeBytes():]
+ s.Value.UnmarshalBytes(src)
+}
+
+// FRemoveXattrReq is used to make FRemoveXattr requests. It has no response.
+type FRemoveXattrReq struct {
+ FD FDID
+ Name SizedString
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (r *FRemoveXattrReq) SizeBytes() int {
+ return r.FD.SizeBytes() + r.Name.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (r *FRemoveXattrReq) MarshalBytes(dst []byte) {
+ r.FD.MarshalUnsafe(dst)
+ dst = dst[r.FD.SizeBytes():]
+ r.Name.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (r *FRemoveXattrReq) UnmarshalBytes(src []byte) {
+ r.FD.UnmarshalUnsafe(src)
+ src = src[r.FD.SizeBytes():]
+ r.Name.UnmarshalBytes(src)
+}
+
+// FListXattrReq is used to make FListXattr requests.
+//
+// +marshal
+type FListXattrReq struct {
+ FD FDID
+ _ uint32
+ Size uint64
+}
+
+// FListXattrResp is used to respond to FListXattr requests.
+type FListXattrResp struct {
+ Xattrs StringArray
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (l *FListXattrResp) SizeBytes() int {
+ return l.Xattrs.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (l *FListXattrResp) MarshalBytes(dst []byte) {
+ l.Xattrs.MarshalBytes(dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (l *FListXattrResp) UnmarshalBytes(src []byte) {
+ l.Xattrs.UnmarshalBytes(src)
+}
diff --git a/pkg/lisafs/sample_message.go b/pkg/lisafs/sample_message.go
new file mode 100644
index 000000000..3868dfa08
--- /dev/null
+++ b/pkg/lisafs/sample_message.go
@@ -0,0 +1,110 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+)
+
+// MsgSimple is a sample packed struct which can be used to test message passing.
+//
+// +marshal slice:Msg1Slice
+type MsgSimple struct {
+ A uint16
+ B uint16
+ C uint32
+ D uint64
+}
+
+// Randomize randomizes the contents of m.
+func (m *MsgSimple) Randomize() {
+ m.A = uint16(rand.Uint32())
+ m.B = uint16(rand.Uint32())
+ m.C = rand.Uint32()
+ m.D = rand.Uint64()
+}
+
+// MsgDynamic is a sample dynamic struct which can be used to test message passing.
+//
+// +marshal dynamic
+type MsgDynamic struct {
+ N primitive.Uint32
+ Arr []MsgSimple
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (m *MsgDynamic) SizeBytes() int {
+ return m.N.SizeBytes() +
+ (int(m.N) * (*MsgSimple)(nil).SizeBytes())
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (m *MsgDynamic) MarshalBytes(dst []byte) {
+ m.N.MarshalUnsafe(dst)
+ dst = dst[m.N.SizeBytes():]
+ MarshalUnsafeMsg1Slice(m.Arr, dst)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (m *MsgDynamic) UnmarshalBytes(src []byte) {
+ m.N.UnmarshalUnsafe(src)
+ src = src[m.N.SizeBytes():]
+ m.Arr = make([]MsgSimple, m.N)
+ UnmarshalUnsafeMsg1Slice(m.Arr, src)
+}
+
+// Randomize randomizes the contents of m.
+func (m *MsgDynamic) Randomize(arrLen int) {
+ m.N = primitive.Uint32(arrLen)
+ m.Arr = make([]MsgSimple, arrLen)
+ for i := 0; i < arrLen; i++ {
+ m.Arr[i].Randomize()
+ }
+}
+
+// P9Version mimics p9.TVersion and p9.Rversion.
+//
+// +marshal dynamic
+type P9Version struct {
+ MSize primitive.Uint32
+ Version string
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (v *P9Version) SizeBytes() int {
+ return (*primitive.Uint32)(nil).SizeBytes() + (*primitive.Uint16)(nil).SizeBytes() + len(v.Version)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (v *P9Version) MarshalBytes(dst []byte) {
+ v.MSize.MarshalUnsafe(dst)
+ dst = dst[v.MSize.SizeBytes():]
+ versionLen := primitive.Uint16(len(v.Version))
+ versionLen.MarshalUnsafe(dst)
+ dst = dst[versionLen.SizeBytes():]
+ copy(dst, v.Version)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (v *P9Version) UnmarshalBytes(src []byte) {
+ v.MSize.UnmarshalUnsafe(src)
+ src = src[v.MSize.SizeBytes():]
+ var versionLen primitive.Uint16
+ versionLen.UnmarshalUnsafe(src)
+ src = src[versionLen.SizeBytes():]
+ v.Version = string(src[:versionLen])
+}
diff --git a/pkg/lisafs/server.go b/pkg/lisafs/server.go
new file mode 100644
index 000000000..7515355ec
--- /dev/null
+++ b/pkg/lisafs/server.go
@@ -0,0 +1,113 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Server serves a filesystem tree. Multiple connections on different mount
+// points can be started on a server. The server provides utilities to safely
+// modify the filesystem tree across its connections (mount points). Note that
+// it does not support synchronizing filesystem tree mutations across other
+// servers serving the same filesystem subtree. Server also manages the
+// lifecycle of all connections.
+type Server struct {
+ // connWg counts the number of active connections being tracked.
+ connWg sync.WaitGroup
+
+ // RenameMu synchronizes rename operations within this filesystem tree.
+ RenameMu sync.RWMutex
+
+ // handlers is a list of RPC handlers which can be indexed by the handler's
+ // corresponding MID.
+ handlers []RPCHandler
+
+ // mountPoints keeps track of all the mount points this server serves.
+ mpMu sync.RWMutex
+ mountPoints []*ControlFD
+
+ // impl is the server implementation which embeds this server.
+ impl ServerImpl
+}
+
+// Init must be called before first use of server.
+func (s *Server) Init(impl ServerImpl) {
+ s.impl = impl
+ s.handlers = handlers[:]
+}
+
+// InitTestOnly is the same as Init except that it allows swapping out the
+// underlying handlers with something custom. This is for test only.
+func (s *Server) InitTestOnly(impl ServerImpl, handlers []RPCHandler) {
+ s.impl = impl
+ s.handlers = handlers
+}
+
+// WithRenameReadLock invokes fn with the server's rename mutex locked for
+// reading. This ensures that no rename operations occur concurrently.
+func (s *Server) WithRenameReadLock(fn func() error) error {
+ s.RenameMu.RLock()
+ err := fn()
+ s.RenameMu.RUnlock()
+ return err
+}
+
+// StartConnection starts the connection on a separate goroutine and tracks it.
+func (s *Server) StartConnection(c *Connection) {
+ s.connWg.Add(1)
+ go func() {
+ c.Run()
+ s.connWg.Done()
+ }()
+}
+
+// Wait waits for all connections started via StartConnection() to terminate.
+func (s *Server) Wait() {
+ s.connWg.Wait()
+}
+
+func (s *Server) addMountPoint(root *ControlFD) {
+ s.mpMu.Lock()
+ defer s.mpMu.Unlock()
+ s.mountPoints = append(s.mountPoints, root)
+}
+
+func (s *Server) forEachMountPoint(fn func(root *ControlFD)) {
+ s.mpMu.RLock()
+ defer s.mpMu.RUnlock()
+ for _, mp := range s.mountPoints {
+ fn(mp)
+ }
+}
+
+// ServerImpl contains the implementation details for a Server.
+// Implementations of ServerImpl should contain their associated Server by
+// value as their first field.
+type ServerImpl interface {
+ // Mount is called when a Mount RPC is made. It mounts the connection at
+ // mountPath.
+ //
+ // Precondition: mountPath == path.Clean(mountPath).
+ Mount(c *Connection, mountPath string) (ControlFDImpl, Inode, error)
+
+ // SupportedMessages returns a list of messages that the server
+ // implementation supports.
+ SupportedMessages() []MID
+
+ // MaxMessageSize is the maximum payload length (in bytes) that can be sent
+ // to this server implementation.
+ MaxMessageSize() uint32
+}
diff --git a/pkg/lisafs/sock.go b/pkg/lisafs/sock.go
new file mode 100644
index 000000000..88210242f
--- /dev/null
+++ b/pkg/lisafs/sock.go
@@ -0,0 +1,208 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "io"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+var (
+ sockHeaderLen = uint32((*sockHeader)(nil).SizeBytes())
+)
+
+// sockHeader is the header present in front of each message received on a UDS.
+//
+// +marshal
+type sockHeader struct {
+ payloadLen uint32
+ message MID
+ _ uint16 // Need to make struct packed.
+}
+
+// sockCommunicator implements Communicator. This is not thread safe.
+type sockCommunicator struct {
+ fdTracker
+ sock *unet.Socket
+ buf []byte
+}
+
+var _ Communicator = (*sockCommunicator)(nil)
+
+func newSockComm(sock *unet.Socket) *sockCommunicator {
+ return &sockCommunicator{
+ sock: sock,
+ buf: make([]byte, sockHeaderLen),
+ }
+}
+
+func (s *sockCommunicator) FD() int {
+ return s.sock.FD()
+}
+
+func (s *sockCommunicator) destroy() {
+ s.sock.Close()
+}
+
+func (s *sockCommunicator) shutdown() {
+ if err := s.sock.Shutdown(); err != nil {
+ log.Warningf("Socket.Shutdown() failed (FD: %d): %v", s.sock.FD(), err)
+ }
+}
+
+func (s *sockCommunicator) resizeBuf(size uint32) {
+ if cap(s.buf) < int(size) {
+ s.buf = s.buf[:cap(s.buf)]
+ s.buf = append(s.buf, make([]byte, int(size)-cap(s.buf))...)
+ } else {
+ s.buf = s.buf[:size]
+ }
+}
+
+// PayloadBuf implements Communicator.PayloadBuf.
+func (s *sockCommunicator) PayloadBuf(size uint32) []byte {
+ s.resizeBuf(sockHeaderLen + size)
+ return s.buf[sockHeaderLen : sockHeaderLen+size]
+}
+
+// SndRcvMessage implements Communicator.SndRcvMessage.
+func (s *sockCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) {
+ if err := s.sndPrepopulatedMsg(m, payloadLen, nil); err != nil {
+ return 0, 0, err
+ }
+
+ return s.rcvMsg(wantFDs)
+}
+
+// sndPrepopulatedMsg assumes that s.buf has already been populated with
+// `payloadLen` bytes of data.
+func (s *sockCommunicator) sndPrepopulatedMsg(m MID, payloadLen uint32, fds []int) error {
+ header := sockHeader{payloadLen: payloadLen, message: m}
+ header.MarshalUnsafe(s.buf)
+ dataLen := sockHeaderLen + payloadLen
+ return writeTo(s.sock, [][]byte{s.buf[:dataLen]}, int(dataLen), fds)
+}
+
+// writeTo writes the passed iovec to the UDS and donates any passed FDs.
+func writeTo(sock *unet.Socket, iovec [][]byte, dataLen int, fds []int) error {
+ w := sock.Writer(true)
+ if len(fds) > 0 {
+ w.PackFDs(fds...)
+ }
+
+ fdsUnpacked := false
+ for n := 0; n < dataLen; {
+ cur, err := w.WriteVec(iovec)
+ if err != nil {
+ return err
+ }
+ n += cur
+
+ // Fast common path.
+ if n >= dataLen {
+ break
+ }
+
+ // Consume iovecs.
+ for consumed := 0; consumed < cur; {
+ if len(iovec[0]) <= cur-consumed {
+ consumed += len(iovec[0])
+ iovec = iovec[1:]
+ } else {
+ iovec[0] = iovec[0][cur-consumed:]
+ break
+ }
+ }
+
+ if n > 0 && !fdsUnpacked {
+ // Don't resend any control message.
+ fdsUnpacked = true
+ w.UnpackFDs()
+ }
+ }
+ return nil
+}
+
+// rcvMsg reads the message header and payload from the UDS. It also populates
+// fds with any donated FDs.
+func (s *sockCommunicator) rcvMsg(wantFDs uint8) (MID, uint32, error) {
+ fds, err := readFrom(s.sock, s.buf[:sockHeaderLen], wantFDs)
+ if err != nil {
+ return 0, 0, err
+ }
+ for _, fd := range fds {
+ s.TrackFD(fd)
+ }
+
+ var header sockHeader
+ header.UnmarshalUnsafe(s.buf)
+
+ // No payload? We are done.
+ if header.payloadLen == 0 {
+ return header.message, 0, nil
+ }
+
+ if _, err := readFrom(s.sock, s.PayloadBuf(header.payloadLen), 0); err != nil {
+ return 0, 0, err
+ }
+
+ return header.message, header.payloadLen, nil
+}
+
+// readFrom fills the passed buffer with data from the socket. It also returns
+// any donated FDs.
+func readFrom(sock *unet.Socket, buf []byte, wantFDs uint8) ([]int, error) {
+ r := sock.Reader(true)
+ r.EnableFDs(int(wantFDs))
+
+ var (
+ fds []int
+ fdInit bool
+ )
+ n := len(buf)
+ for got := 0; got < n; {
+ cur, err := r.ReadVec([][]byte{buf[got:]})
+
+ // Ignore EOF if cur > 0.
+ if err != nil && (err != io.EOF || cur == 0) {
+ r.CloseFDs()
+ return nil, err
+ }
+
+ if !fdInit && cur > 0 {
+ fds, err = r.ExtractFDs()
+ if err != nil {
+ return nil, err
+ }
+
+ fdInit = true
+ r.EnableFDs(0)
+ }
+
+ got += cur
+ }
+ return fds, nil
+}
+
+func closeFDs(fds []int) {
+ for _, fd := range fds {
+ if fd >= 0 {
+ unix.Close(fd)
+ }
+ }
+}
diff --git a/pkg/lisafs/sock_test.go b/pkg/lisafs/sock_test.go
new file mode 100644
index 000000000..387f4b7a8
--- /dev/null
+++ b/pkg/lisafs/sock_test.go
@@ -0,0 +1,217 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs
+
+import (
+ "bytes"
+ "math/rand"
+ "reflect"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+func runSocketTest(t *testing.T, fun1 func(*sockCommunicator), fun2 func(*sockCommunicator)) {
+ sock1, sock2, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer sock1.Close()
+ defer sock2.Close()
+
+ var testWg sync.WaitGroup
+ testWg.Add(2)
+
+ go func() {
+ fun1(newSockComm(sock1))
+ testWg.Done()
+ }()
+
+ go func() {
+ fun2(newSockComm(sock2))
+ testWg.Done()
+ }()
+
+ testWg.Wait()
+}
+
+func TestReadWrite(t *testing.T) {
+ // Create random data to send.
+ n := 10000
+ data := make([]byte, n)
+ if _, err := rand.Read(data); err != nil {
+ t.Fatalf("rand.Read(data) failed: %v", err)
+ }
+
+ runSocketTest(t, func(comm *sockCommunicator) {
+ // Scatter that data into two parts using Iovecs while sending.
+ mid := n / 2
+ if err := writeTo(comm.sock, [][]byte{data[:mid], data[mid:]}, n, nil); err != nil {
+ t.Errorf("writeTo socket failed: %v", err)
+ }
+ }, func(comm *sockCommunicator) {
+ gotData := make([]byte, n)
+ if _, err := readFrom(comm.sock, gotData, 0); err != nil {
+ t.Fatalf("reading from socket failed: %v", err)
+ }
+
+ // Make sure we got the right data.
+ if res := bytes.Compare(data, gotData); res != 0 {
+ t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData)
+ }
+ })
+}
+
+func TestFDDonation(t *testing.T) {
+ n := 10
+ data := make([]byte, n)
+ if _, err := rand.Read(data); err != nil {
+ t.Fatalf("rand.Read(data) failed: %v", err)
+ }
+
+ // Try donating FDs to these files.
+ path1 := "/dev/null"
+ path2 := "/dev"
+ path3 := "/dev/random"
+
+ runSocketTest(t, func(comm *sockCommunicator) {
+ devNullFD, err := unix.Open(path1, unix.O_RDONLY, 0)
+ defer unix.Close(devNullFD)
+ if err != nil {
+ t.Fatalf("open(%s) failed: %v", path1, err)
+ }
+ devFD, err := unix.Open(path2, unix.O_RDONLY, 0)
+ defer unix.Close(devFD)
+ if err != nil {
+ t.Fatalf("open(%s) failed: %v", path2, err)
+ }
+ devRandomFD, err := unix.Open(path3, unix.O_RDONLY, 0)
+ defer unix.Close(devRandomFD)
+ if err != nil {
+ t.Fatalf("open(%s) failed: %v", path2, err)
+ }
+ if err := writeTo(comm.sock, [][]byte{data}, n, []int{devNullFD, devFD, devRandomFD}); err != nil {
+ t.Errorf("writeTo socket failed: %v", err)
+ }
+ }, func(comm *sockCommunicator) {
+ gotData := make([]byte, n)
+ fds, err := readFrom(comm.sock, gotData, 3)
+ if err != nil {
+ t.Fatalf("reading from socket failed: %v", err)
+ }
+ defer closeFDs(fds[:])
+
+ if res := bytes.Compare(data, gotData); res != 0 {
+ t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData)
+ }
+
+ if len(fds) != 3 {
+ t.Fatalf("wanted 3 FD, got %d", len(fds))
+ }
+
+ // Check that the FDs actually point to the correct file.
+ compareFDWithFile(t, fds[0], path1)
+ compareFDWithFile(t, fds[1], path2)
+ compareFDWithFile(t, fds[2], path3)
+ })
+}
+
+func compareFDWithFile(t *testing.T, fd int, path string) {
+ var want unix.Stat_t
+ if err := unix.Stat(path, &want); err != nil {
+ t.Fatalf("stat(%s) failed: %v", path, err)
+ }
+
+ var got unix.Stat_t
+ if err := unix.Fstat(fd, &got); err != nil {
+ t.Fatalf("fstat on donated FD failed: %v", err)
+ }
+
+ if got.Ino != want.Ino || got.Dev != want.Dev {
+ t.Errorf("FD does not point to %s, want = %+v, got = %+v", path, want, got)
+ }
+}
+
+func testSndMsg(comm *sockCommunicator, m MID, msg marshal.Marshallable) error {
+ var payloadLen uint32
+ if msg != nil {
+ payloadLen = uint32(msg.SizeBytes())
+ msg.MarshalUnsafe(comm.PayloadBuf(payloadLen))
+ }
+ return comm.sndPrepopulatedMsg(m, payloadLen, nil)
+}
+
+func TestSndRcvMessage(t *testing.T) {
+ req := &MsgSimple{}
+ req.Randomize()
+ reqM := MID(1)
+
+ // Create a massive random response.
+ var resp MsgDynamic
+ resp.Randomize(100)
+ respM := MID(2)
+
+ runSocketTest(t, func(comm *sockCommunicator) {
+ if err := testSndMsg(comm, reqM, req); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ checkMessageReceive(t, comm, respM, &resp)
+ }, func(comm *sockCommunicator) {
+ checkMessageReceive(t, comm, reqM, req)
+ if err := testSndMsg(comm, respM, &resp); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ })
+}
+
+func TestSndRcvMessageNoPayload(t *testing.T) {
+ reqM := MID(1)
+ respM := MID(2)
+ runSocketTest(t, func(comm *sockCommunicator) {
+ if err := testSndMsg(comm, reqM, nil); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ checkMessageReceive(t, comm, respM, nil)
+ }, func(comm *sockCommunicator) {
+ checkMessageReceive(t, comm, reqM, nil)
+ if err := testSndMsg(comm, respM, nil); err != nil {
+ t.Errorf("writeMessageTo failed: %v", err)
+ }
+ })
+}
+
+func checkMessageReceive(t *testing.T, comm *sockCommunicator, wantM MID, wantMsg marshal.Marshallable) {
+ gotM, payloadLen, err := comm.rcvMsg(0)
+ if err != nil {
+ t.Fatalf("readMessageFrom failed: %v", err)
+ }
+ if gotM != wantM {
+ t.Errorf("got incorrect message ID: got = %d, want = %d", gotM, wantM)
+ }
+ if wantMsg == nil {
+ if payloadLen != 0 {
+ t.Errorf("no payload expect but got %d bytes", payloadLen)
+ }
+ } else {
+ gotMsg := reflect.New(reflect.ValueOf(wantMsg).Elem().Type()).Interface().(marshal.Marshallable)
+ gotMsg.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
+ if !reflect.DeepEqual(wantMsg, gotMsg) {
+ t.Errorf("msg differs: want = %+v, got = %+v", wantMsg, gotMsg)
+ }
+ }
+}
diff --git a/pkg/lisafs/testsuite/BUILD b/pkg/lisafs/testsuite/BUILD
new file mode 100644
index 000000000..b4a542b3a
--- /dev/null
+++ b/pkg/lisafs/testsuite/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "testsuite",
+ testonly = True,
+ srcs = ["testsuite.go"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/lisafs",
+ "//pkg/unet",
+ "@com_github_syndtr_gocapability//capability:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/lisafs/testsuite/testsuite.go b/pkg/lisafs/testsuite/testsuite.go
new file mode 100644
index 000000000..5fc7c364d
--- /dev/null
+++ b/pkg/lisafs/testsuite/testsuite.go
@@ -0,0 +1,637 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testsuite provides an integration testing suite for lisafs.
+// These tests are intended for servers serving the local filesystem.
+package testsuite
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "math/rand"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/syndtr/gocapability/capability"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/lisafs"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// Tester is the client code using this test suite. This interface abstracts
+// away all the caller specific details.
+type Tester interface {
+ // NewServer returns a new instance of the tester server.
+ NewServer(t *testing.T) *lisafs.Server
+
+ // LinkSupported returns true if the backing server supports LinkAt.
+ LinkSupported() bool
+
+ // SetUserGroupIDSupported returns true if the backing server supports
+ // changing UID/GID for files.
+ SetUserGroupIDSupported() bool
+}
+
+// RunAllLocalFSTests runs all local FS tests as subtests.
+func RunAllLocalFSTests(t *testing.T, tester Tester) {
+ for name, testFn := range localFSTests {
+ t.Run(name, func(t *testing.T) {
+ runServerClient(t, tester, testFn)
+ })
+ }
+}
+
+type testFunc func(context.Context, *testing.T, Tester, lisafs.ClientFD)
+
+var localFSTests map[string]testFunc = map[string]testFunc{
+ "Stat": testStat,
+ "RegularFileIO": testRegularFileIO,
+ "RegularFileOpen": testRegularFileOpen,
+ "SetStat": testSetStat,
+ "Allocate": testAllocate,
+ "StatFS": testStatFS,
+ "Unlink": testUnlink,
+ "Symlink": testSymlink,
+ "HardLink": testHardLink,
+ "Walk": testWalk,
+ "Rename": testRename,
+ "Mknod": testMknod,
+ "Getdents": testGetdents,
+}
+
+func runServerClient(t *testing.T, tester Tester, testFn testFunc) {
+ mountPath, err := ioutil.TempDir(os.Getenv("TEST_TMPDIR"), "")
+ if err != nil {
+ t.Fatalf("creation of temporary mountpoint failed: %v", err)
+ }
+ defer os.RemoveAll(mountPath)
+
+ // fsgofer should run with a umask of 0, because we want to preserve file
+ // modes exactly for testing purposes.
+ unix.Umask(0)
+
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ t.Fatalf("socketpair got err %v expected nil", err)
+ }
+
+ server := tester.NewServer(t)
+ conn, err := server.CreateConnection(serverSocket, false /* readonly */)
+ if err != nil {
+ t.Fatalf("starting connection failed: %v", err)
+ return
+ }
+ server.StartConnection(conn)
+
+ c, root, err := lisafs.NewClient(clientSocket, mountPath)
+ if err != nil {
+ t.Fatalf("client creation failed: %v", err)
+ }
+
+ if !root.ControlFD.Ok() {
+ t.Fatalf("root control FD is not valid")
+ }
+ rootFile := c.NewFD(root.ControlFD)
+
+ ctx := context.Background()
+ testFn(ctx, t, tester, rootFile)
+ closeFD(ctx, t, rootFile)
+
+ c.Close() // This should trigger client and server shutdown.
+ server.Wait()
+}
+
+func closeFD(ctx context.Context, t testing.TB, fdLisa lisafs.ClientFD) {
+ if err := fdLisa.Close(ctx); err != nil {
+ t.Errorf("failed to close FD: %v", err)
+ }
+}
+
+func statTo(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, stat *linux.Statx) {
+ if err := fdLisa.StatTo(ctx, stat); err != nil {
+ t.Fatalf("stat failed: %v", err)
+ }
+}
+
+func openCreateFile(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx, lisafs.ClientFD, int) {
+ child, childFD, childHostFD, err := fdLisa.OpenCreateAt(ctx, name, unix.O_RDWR, 0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()))
+ if err != nil {
+ t.Fatalf("OpenCreateAt failed: %v", err)
+ }
+ if childHostFD == -1 {
+ t.Error("no host FD donated")
+ }
+ client := fdLisa.Client()
+ return client.NewFD(child.ControlFD), child.Stat, fdLisa.Client().NewFD(childFD), childHostFD
+}
+
+func openFile(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, flags uint32, isReg bool) (lisafs.ClientFD, int) {
+ newFD, hostFD, err := fdLisa.OpenAt(ctx, flags)
+ if err != nil {
+ t.Fatalf("OpenAt failed: %v", err)
+ }
+ if hostFD == -1 && isReg {
+ t.Error("no host FD donated")
+ }
+ return fdLisa.Client().NewFD(newFD), hostFD
+}
+
+func unlinkFile(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string, isDir bool) {
+ var flags uint32
+ if isDir {
+ flags = unix.AT_REMOVEDIR
+ }
+ if err := dir.UnlinkAt(ctx, name, flags); err != nil {
+ t.Errorf("unlinking file %s failed: %v", name, err)
+ }
+}
+
+func symlink(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name, target string) (lisafs.ClientFD, linux.Statx) {
+ linkIno, err := dir.SymlinkAt(ctx, name, target, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()))
+ if err != nil {
+ t.Fatalf("symlink failed: %v", err)
+ }
+ return dir.Client().NewFD(linkIno.ControlFD), linkIno.Stat
+}
+
+func link(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string, target lisafs.ClientFD) (lisafs.ClientFD, linux.Statx) {
+ linkIno, err := dir.LinkAt(ctx, target.ID(), name)
+ if err != nil {
+ t.Fatalf("link failed: %v", err)
+ }
+ return dir.Client().NewFD(linkIno.ControlFD), linkIno.Stat
+}
+
+func mkdir(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx) {
+ childIno, err := dir.MkdirAt(ctx, name, 0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()))
+ if err != nil {
+ t.Fatalf("mkdir failed: %v", err)
+ }
+ return dir.Client().NewFD(childIno.ControlFD), childIno.Stat
+}
+
+func mknod(ctx context.Context, t *testing.T, dir lisafs.ClientFD, name string) (lisafs.ClientFD, linux.Statx) {
+ nodeIno, err := dir.MknodAt(ctx, name, unix.S_IFREG|0777, lisafs.UID(unix.Getuid()), lisafs.GID(unix.Getgid()), 0, 0)
+ if err != nil {
+ t.Fatalf("mknod failed: %v", err)
+ }
+ return dir.Client().NewFD(nodeIno.ControlFD), nodeIno.Stat
+}
+
+func walk(ctx context.Context, t *testing.T, dir lisafs.ClientFD, names []string) []lisafs.Inode {
+ _, inodes, err := dir.WalkMultiple(ctx, names)
+ if err != nil {
+ t.Fatalf("walk failed while trying to walk components %+v: %v", names, err)
+ }
+ return inodes
+}
+
+func walkStat(ctx context.Context, t *testing.T, dir lisafs.ClientFD, names []string) []linux.Statx {
+ stats, err := dir.WalkStat(ctx, names)
+ if err != nil {
+ t.Fatalf("walk failed while trying to walk components %+v: %v", names, err)
+ }
+ return stats
+}
+
+func writeFD(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, buf []byte) error {
+ count, err := fdLisa.Write(ctx, buf, off)
+ if err != nil {
+ return err
+ }
+ if int(count) != len(buf) {
+ t.Errorf("partial write: buf size = %d, written = %d", len(buf), count)
+ }
+ return nil
+}
+
+func readFDAndCmp(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, want []byte) {
+ buf := make([]byte, len(want))
+ n, err := fdLisa.Read(ctx, buf, off)
+ if err != nil {
+ t.Errorf("read failed: %v", err)
+ return
+ }
+ if int(n) != len(want) {
+ t.Errorf("partial read: buf size = %d, read = %d", len(want), n)
+ return
+ }
+ if bytes.Compare(buf, want) != 0 {
+ t.Errorf("bytes read differ from what was expected: want = %v, got = %v", want, buf)
+ }
+}
+
+func allocateAndVerify(ctx context.Context, t *testing.T, fdLisa lisafs.ClientFD, off uint64, length uint64) {
+ if err := fdLisa.Allocate(ctx, 0, off, length); err != nil {
+ t.Fatalf("fallocate failed: %v", err)
+ }
+
+ var stat linux.Statx
+ statTo(ctx, t, fdLisa, &stat)
+ if want := off + length; stat.Size != want {
+ t.Errorf("incorrect file size after allocate: expected %d, got %d", off+length, stat.Size)
+ }
+}
+
+func cmpStatx(t *testing.T, want, got linux.Statx) {
+ if got.Mask&unix.STATX_MODE != 0 && want.Mask&unix.STATX_MODE != 0 {
+ if got.Mode != want.Mode {
+ t.Errorf("mode differs: want %d, got %d", want.Mode, got.Mode)
+ }
+ }
+ if got.Mask&unix.STATX_INO != 0 && want.Mask&unix.STATX_INO != 0 {
+ if got.Ino != want.Ino {
+ t.Errorf("inode number differs: want %d, got %d", want.Ino, got.Ino)
+ }
+ }
+ if got.Mask&unix.STATX_NLINK != 0 && want.Mask&unix.STATX_NLINK != 0 {
+ if got.Nlink != want.Nlink {
+ t.Errorf("nlink differs: want %d, got %d", want.Nlink, got.Nlink)
+ }
+ }
+ if got.Mask&unix.STATX_UID != 0 && want.Mask&unix.STATX_UID != 0 {
+ if got.UID != want.UID {
+ t.Errorf("UID differs: want %d, got %d", want.UID, got.UID)
+ }
+ }
+ if got.Mask&unix.STATX_GID != 0 && want.Mask&unix.STATX_GID != 0 {
+ if got.GID != want.GID {
+ t.Errorf("GID differs: want %d, got %d", want.GID, got.GID)
+ }
+ }
+ if got.Mask&unix.STATX_SIZE != 0 && want.Mask&unix.STATX_SIZE != 0 {
+ if got.Size != want.Size {
+ t.Errorf("size differs: want %d, got %d", want.Size, got.Size)
+ }
+ }
+ if got.Mask&unix.STATX_BLOCKS != 0 && want.Mask&unix.STATX_BLOCKS != 0 {
+ if got.Blocks != want.Blocks {
+ t.Errorf("blocks differs: want %d, got %d", want.Blocks, got.Blocks)
+ }
+ }
+ if got.Mask&unix.STATX_ATIME != 0 && want.Mask&unix.STATX_ATIME != 0 {
+ if got.Atime != want.Atime {
+ t.Errorf("atime differs: want %d, got %d", want.Atime, got.Atime)
+ }
+ }
+ if got.Mask&unix.STATX_MTIME != 0 && want.Mask&unix.STATX_MTIME != 0 {
+ if got.Mtime != want.Mtime {
+ t.Errorf("mtime differs: want %d, got %d", want.Mtime, got.Mtime)
+ }
+ }
+ if got.Mask&unix.STATX_CTIME != 0 && want.Mask&unix.STATX_CTIME != 0 {
+ if got.Ctime != want.Ctime {
+ t.Errorf("ctime differs: want %d, got %d", want.Ctime, got.Ctime)
+ }
+ }
+}
+
+func hasCapability(c capability.Cap) bool {
+ caps, err := capability.NewPid2(os.Getpid())
+ if err != nil {
+ return false
+ }
+ if err := caps.Load(); err != nil {
+ return false
+ }
+ return caps.Get(capability.EFFECTIVE, c)
+}
+
+func testStat(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ var rootStat linux.Statx
+ if err := root.StatTo(ctx, &rootStat); err != nil {
+ t.Errorf("stat on the root dir failed: %v", err)
+ }
+
+ if ftype := rootStat.Mode & unix.S_IFMT; ftype != unix.S_IFDIR {
+ t.Errorf("root inode is not a directory, file type = %d", ftype)
+ }
+}
+
+func testRegularFileIO(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ // Test Read/Write RPCs with 2MB of data to test IO in chunks.
+ data := make([]byte, 1<<21)
+ rand.Read(data)
+ if err := writeFD(ctx, t, fd, 0, data); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ readFDAndCmp(ctx, t, fd, 0, data)
+ readFDAndCmp(ctx, t, fd, 50, data[50:])
+
+ // Make sure the host FD is configured properly.
+ hostReadData := make([]byte, len(data))
+ if n, err := unix.Pread(hostFD, hostReadData, 0); err != nil {
+ t.Errorf("host read failed: %v", err)
+ } else if n != len(hostReadData) {
+ t.Errorf("partial read: buf size = %d, read = %d", len(hostReadData), n)
+ } else if bytes.Compare(hostReadData, data) != 0 {
+ t.Errorf("bytes read differ from what was expected: want = %v, got = %v", data, hostReadData)
+ }
+
+ // Test syncing the writable FD.
+ if err := fd.Sync(ctx); err != nil {
+ t.Errorf("syncing the FD failed: %v", err)
+ }
+}
+
+func testRegularFileOpen(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ // Open a readonly FD and try writing to it to get an EBADF.
+ roFile, roHostFD := openFile(ctx, t, controlFile, unix.O_RDONLY, true /* isReg */)
+ defer closeFD(ctx, t, roFile)
+ defer unix.Close(roHostFD)
+ if err := writeFD(ctx, t, roFile, 0, []byte{1, 2, 3}); err != unix.EBADF {
+ t.Errorf("writing to read only FD should generate EBADF, but got %v", err)
+ }
+}
+
+func testSetStat(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ now := time.Now()
+ wantStat := linux.Statx{
+ Mask: unix.STATX_MODE | unix.STATX_ATIME | unix.STATX_MTIME | unix.STATX_SIZE,
+ Mode: 0760,
+ UID: uint32(unix.Getuid()),
+ GID: uint32(unix.Getgid()),
+ Size: 50,
+ Atime: linux.NsecToStatxTimestamp(now.UnixNano()),
+ Mtime: linux.NsecToStatxTimestamp(now.UnixNano()),
+ }
+ if tester.SetUserGroupIDSupported() {
+ wantStat.Mask |= unix.STATX_UID | unix.STATX_GID
+ }
+ failureMask, failureErr, err := controlFile.SetStat(ctx, &wantStat)
+ if err != nil {
+ t.Fatalf("setstat failed: %v", err)
+ }
+ if failureMask != 0 {
+ t.Fatalf("some setstat operations failed: failureMask = %#b, failureErr = %v", failureMask, failureErr)
+ }
+
+ // Verify that attributes were updated.
+ var gotStat linux.Statx
+ statTo(ctx, t, controlFile, &gotStat)
+ if gotStat.Mode&07777 != wantStat.Mode ||
+ gotStat.Size != wantStat.Size ||
+ gotStat.Atime.ToNsec() != wantStat.Atime.ToNsec() ||
+ gotStat.Mtime.ToNsec() != wantStat.Mtime.ToNsec() ||
+ (tester.SetUserGroupIDSupported() && (uint32(gotStat.UID) != wantStat.UID || uint32(gotStat.GID) != wantStat.GID)) {
+ t.Errorf("setStat did not update file correctly: setStat = %+v, stat = %+v", wantStat, gotStat)
+ }
+}
+
+func testAllocate(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ allocateAndVerify(ctx, t, fd, 0, 40)
+ allocateAndVerify(ctx, t, fd, 20, 100)
+}
+
+func testStatFS(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ var statFS lisafs.StatFS
+ if err := root.StatFSTo(ctx, &statFS); err != nil {
+ t.Errorf("statfs failed: %v", err)
+ }
+}
+
+func testUnlink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ controlFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ unlinkFile(ctx, t, root, name, false /* isDir */)
+ if inodes := walk(ctx, t, root, []string{name}); len(inodes) > 0 {
+ t.Errorf("deleted file should not be generating inodes on walk: %+v", inodes)
+ }
+}
+
+func testSymlink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ target := "/tmp/some/path"
+ name := "symlinkFile"
+ link, linkStat := symlink(ctx, t, root, name, target)
+ defer closeFD(ctx, t, link)
+
+ if linkStat.Mode&unix.S_IFMT != unix.S_IFLNK {
+ t.Errorf("stat return from symlink RPC indicates that the inode is not a symlink: mode = %d", linkStat.Mode)
+ }
+
+ if gotTarget, err := link.ReadLinkAt(ctx); err != nil {
+ t.Fatalf("readlink failed: %v", err)
+ } else if gotTarget != target {
+ t.Errorf("readlink return incorrect target: expected %q, got %q", target, gotTarget)
+ }
+}
+
+func testHardLink(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ if !tester.LinkSupported() {
+ t.Skipf("server does not support LinkAt RPC")
+ }
+ if !hasCapability(capability.CAP_DAC_READ_SEARCH) {
+ t.Skipf("TestHardLink requires CAP_DAC_READ_SEARCH, running as %d", unix.Getuid())
+ }
+ name := "tempFile"
+ controlFile, fileIno, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, controlFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ link, linkStat := link(ctx, t, root, name, controlFile)
+ defer closeFD(ctx, t, link)
+
+ if linkStat.Ino != fileIno.Ino {
+ t.Errorf("hard linked files have different inode numbers: %d %d", linkStat.Ino, fileIno.Ino)
+ }
+ if linkStat.DevMinor != fileIno.DevMinor {
+ t.Errorf("hard linked files have different minor device numbers: %d %d", linkStat.DevMinor, fileIno.DevMinor)
+ }
+ if linkStat.DevMajor != fileIno.DevMajor {
+ t.Errorf("hard linked files have different major device numbers: %d %d", linkStat.DevMajor, fileIno.DevMajor)
+ }
+}
+
+func testWalk(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ // Create 10 nested directories.
+ n := 10
+ curDir := root
+
+ dirNames := make([]string, 0, n)
+ for i := 0; i < n; i++ {
+ name := fmt.Sprintf("tmpdir-%d", i)
+ childDir, _ := mkdir(ctx, t, curDir, name)
+ defer closeFD(ctx, t, childDir)
+ defer unlinkFile(ctx, t, curDir, name, true /* isDir */)
+
+ curDir = childDir
+ dirNames = append(dirNames, name)
+ }
+
+ // Walk all these directories. Add some junk at the end which should not be
+ // walked on.
+ dirNames = append(dirNames, []string{"a", "b", "c"}...)
+ inodes := walk(ctx, t, root, dirNames)
+ if len(inodes) != n {
+ t.Errorf("walk returned the incorrect number of inodes: wanted %d, got %d", n, len(inodes))
+ }
+
+ // Close all control FDs and collect stat results for all dirs including
+ // the root directory.
+ dirStats := make([]linux.Statx, 0, n+1)
+ var stat linux.Statx
+ statTo(ctx, t, root, &stat)
+ dirStats = append(dirStats, stat)
+ for _, inode := range inodes {
+ dirStats = append(dirStats, inode.Stat)
+ closeFD(ctx, t, root.Client().NewFD(inode.ControlFD))
+ }
+
+ // Test WalkStat which additionally returns Statx for root because the first
+ // path component is "".
+ dirNames = append([]string{""}, dirNames...)
+ gotStats := walkStat(ctx, t, root, dirNames)
+ if len(gotStats) != len(dirStats) {
+ t.Errorf("walkStat returned the incorrect number of statx: wanted %d, got %d", len(dirStats), len(gotStats))
+ } else {
+ for i := range gotStats {
+ cmpStatx(t, dirStats[i], gotStats[i])
+ }
+ }
+}
+
+func testRename(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "tempFile"
+ tempFile, _, fd, hostFD := openCreateFile(ctx, t, root, name)
+ defer closeFD(ctx, t, tempFile)
+ defer closeFD(ctx, t, fd)
+ defer unix.Close(hostFD)
+
+ tempDir, _ := mkdir(ctx, t, root, "tempDir")
+ defer closeFD(ctx, t, tempDir)
+
+ // Move tempFile into tempDir.
+ if err := tempFile.RenameTo(ctx, tempDir.ID(), "movedFile"); err != nil {
+ t.Fatalf("rename failed: %v", err)
+ }
+
+ inodes := walkStat(ctx, t, root, []string{"tempDir", "movedFile"})
+ if len(inodes) != 2 {
+ t.Errorf("expected 2 files on walk but only found %d", len(inodes))
+ }
+}
+
+func testMknod(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ name := "namedPipe"
+ pipeFile, pipeStat := mknod(ctx, t, root, name)
+ defer closeFD(ctx, t, pipeFile)
+
+ var stat linux.Statx
+ statTo(ctx, t, pipeFile, &stat)
+
+ if stat.Mode != pipeStat.Mode {
+ t.Errorf("mknod mode is incorrect: want %d, got %d", pipeStat.Mode, stat.Mode)
+ }
+ if stat.UID != pipeStat.UID {
+ t.Errorf("mknod UID is incorrect: want %d, got %d", pipeStat.UID, stat.UID)
+ }
+ if stat.GID != pipeStat.GID {
+ t.Errorf("mknod GID is incorrect: want %d, got %d", pipeStat.GID, stat.GID)
+ }
+}
+
+func testGetdents(ctx context.Context, t *testing.T, tester Tester, root lisafs.ClientFD) {
+ tempDir, _ := mkdir(ctx, t, root, "tempDir")
+ defer closeFD(ctx, t, tempDir)
+ defer unlinkFile(ctx, t, root, "tempDir", true /* isDir */)
+
+ // Create 10 files in tempDir.
+ n := 10
+ fileStats := make(map[string]linux.Statx)
+ for i := 0; i < n; i++ {
+ name := fmt.Sprintf("file-%d", i)
+ newFile, fileStat := mknod(ctx, t, tempDir, name)
+ defer closeFD(ctx, t, newFile)
+ defer unlinkFile(ctx, t, tempDir, name, false /* isDir */)
+
+ fileStats[name] = fileStat
+ }
+
+ // Use opened directory FD for getdents.
+ openDirFile, _ := openFile(ctx, t, tempDir, unix.O_RDONLY, false /* isReg */)
+ defer closeFD(ctx, t, openDirFile)
+
+ dirents := make([]lisafs.Dirent64, 0, n)
+ for i := 0; i < n+2; i++ {
+ gotDirents, err := openDirFile.Getdents64(ctx, 40)
+ if err != nil {
+ t.Fatalf("getdents failed: %v", err)
+ }
+ if len(gotDirents) == 0 {
+ break
+ }
+ for _, dirent := range gotDirents {
+ if dirent.Name != "." && dirent.Name != ".." {
+ dirents = append(dirents, dirent)
+ }
+ }
+ }
+
+ if len(dirents) != n {
+ t.Errorf("got incorrect number of dirents: wanted %d, got %d", n, len(dirents))
+ }
+ for _, dirent := range dirents {
+ stat, ok := fileStats[string(dirent.Name)]
+ if !ok {
+ t.Errorf("received a dirent that was not created: %+v", dirent)
+ continue
+ }
+
+ if dirent.Type != unix.DT_REG {
+ t.Errorf("dirent type of %s is incorrect: %d", dirent.Name, dirent.Type)
+ }
+ if uint64(dirent.Ino) != stat.Ino {
+ t.Errorf("dirent ino of %s is incorrect: want %d, got %d", dirent.Name, stat.Ino, dirent.Ino)
+ }
+ if uint32(dirent.DevMinor) != stat.DevMinor {
+ t.Errorf("dirent dev minor of %s is incorrect: want %d, got %d", dirent.Name, stat.DevMinor, dirent.DevMinor)
+ }
+ if uint32(dirent.DevMajor) != stat.DevMajor {
+ t.Errorf("dirent dev major of %s is incorrect: want %d, got %d", dirent.Name, stat.DevMajor, dirent.DevMajor)
+ }
+ }
+}
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index eb496f02f..d618da820 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -115,7 +115,7 @@ type Client struct {
// channels is the set of all initialized channels.
channels []*channel
- // availableChannels is a FIFO of inactive channels.
+ // availableChannels is a LIFO of inactive channels.
availableChannels []*channel
// -- below corresponds to sendRecvLegacy --
diff --git a/pkg/ring0/defs.go b/pkg/ring0/defs.go
index b6e2012e8..38ce9be1e 100644
--- a/pkg/ring0/defs.go
+++ b/pkg/ring0/defs.go
@@ -77,6 +77,9 @@ type CPU struct {
// calls and exceptions via the Registers function.
registers arch.Registers
+ // floatingPointState holds floating point state.
+ floatingPointState fpu.State
+
// hooks are kernel hooks.
hooks Hooks
}
@@ -90,6 +93,15 @@ func (c *CPU) Registers() *arch.Registers {
return &c.registers
}
+// FloatingPointState returns the kernel floating point state.
+//
+// This is explicitly safe to call during KernelException and KernelSyscall.
+//
+//go:nosplit
+func (c *CPU) FloatingPointState() *fpu.State {
+ return &c.floatingPointState
+}
+
// SwitchOpts are passed to the Switch function.
type SwitchOpts struct {
// Registers are the user register state.
diff --git a/pkg/ring0/defs_amd64.go b/pkg/ring0/defs_amd64.go
index 24f6e4cde..81e90dbf7 100644
--- a/pkg/ring0/defs_amd64.go
+++ b/pkg/ring0/defs_amd64.go
@@ -116,6 +116,11 @@ type CPUArchState struct {
errorType uintptr
*kernelEntry
+
+ // Copies of global variables, stored in CPU so that they can be used by
+ // syscall and exception handlers (in the upper address space).
+ hasXSAVE bool
+ hasXSAVEOPT bool
}
// ErrorCode returns the last error code.
diff --git a/pkg/ring0/entry_amd64.go b/pkg/ring0/entry_amd64.go
index afd646b0b..13ad4e4df 100644
--- a/pkg/ring0/entry_amd64.go
+++ b/pkg/ring0/entry_amd64.go
@@ -39,11 +39,6 @@ func sysenter()
// assembly to get the ABI0 (i.e., primary) address.
func addrOfSysenter() uintptr
-// swapgs swaps the current GS value.
-//
-// This must be called prior to sysret/iret.
-func swapgs()
-
// jumpToKernel jumps to the kernel version of the current RIP.
func jumpToKernel()
diff --git a/pkg/ring0/entry_amd64.s b/pkg/ring0/entry_amd64.s
index 520bd9f57..d2913f190 100644
--- a/pkg/ring0/entry_amd64.s
+++ b/pkg/ring0/entry_amd64.s
@@ -142,8 +142,103 @@ TEXT ·jumpToUser(SB),NOSPLIT,$0
MOVQ AX, 0(SP)
RET
+// See kernel_amd64.go.
+//
+// The 16-byte frame size is for the saved values of MXCSR and the x87 control
+// word.
+TEXT ·doSwitchToUser(SB),NOSPLIT,$16-48
+ // We are passed pointers to heap objects, but do not store them in our
+ // local frame.
+ NO_LOCAL_POINTERS
+
+ // MXCSR and the x87 control word are the only floating point state
+ // that is callee-save and thus we must save.
+ STMXCSR mxcsr-0(SP)
+ FSTCW cw-8(SP)
+
+ // Restore application floating point state.
+ MOVQ cpu+0(FP), SI
+ MOVQ fpState+16(FP), DI
+ MOVB ·hasXSAVE(SB), BX
+ TESTB BX, BX
+ JZ no_xrstor
+ // Use xrstor to restore all available fp state. For now, we restore
+ // everything unconditionally by setting the implicit operand edx:eax
+ // (the "requested feature bitmap") to all 1's.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x2f // XRSTOR64 0(DI)
+ JMP fprestore_done
+no_xrstor:
+ // Fall back to fxrstor if xsave is not available.
+ FXRSTOR64 0(DI)
+fprestore_done:
+
+ // Set application GS.
+ MOVQ regs+8(FP), R8
+ SWAP_GS()
+ MOVQ PTRACE_GS_BASE(R8), AX
+ PUSHQ AX
+ CALL ·writeGS(SB)
+ POPQ AX
+
+ // Call sysret() or iret().
+ MOVQ userCR3+24(FP), CX
+ MOVQ needIRET+32(FP), R9
+ ADDQ $-32, SP
+ MOVQ SI, 0(SP) // cpu
+ MOVQ R8, 8(SP) // regs
+ MOVQ CX, 16(SP) // userCR3
+ TESTQ R9, R9
+ JNZ do_iret
+ CALL ·sysret(SB)
+ JMP done_sysret_or_iret
+do_iret:
+ CALL ·iret(SB)
+done_sysret_or_iret:
+ MOVQ 24(SP), AX // vector
+ ADDQ $32, SP
+ MOVQ AX, vector+40(FP)
+
+ // Save application floating point state.
+ MOVQ fpState+16(FP), DI
+ MOVB ·hasXSAVE(SB), BX
+ MOVB ·hasXSAVEOPT(SB), CX
+ TESTB BX, BX
+ JZ no_xsave
+ // Use xsave/xsaveopt to save all extended state.
+ // We save everything unconditionally by setting RFBM to all 1's.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ TESTB CX, CX
+ JZ no_xsaveopt
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI)
+ JMP fpsave_done
+no_xsaveopt:
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI)
+ JMP fpsave_done
+no_xsave:
+ FXSAVE64 0(DI)
+fpsave_done:
+
+ // Restore MXCSR and the x87 control word after one of the two floating
+ // point save cases above, to ensure the application versions are saved
+ // before being clobbered here.
+ LDMXCSR mxcsr-0(SP)
+
+ // FLDCW is a "waiting" x87 instruction, meaning it checks for pending
+ // unmasked exceptions before executing. Thus if userspace has unmasked
+ // an exception and has one pending, it can be raised by FLDCW even
+ // though the new control word will mask exceptions. To prevent this,
+ // we must first clear pending exceptions (which will be restored by
+ // XRSTOR, et al).
+ BYTE $0xDB; BYTE $0xE2; // FNCLEX
+ FLDCW cw-8(SP)
+
+ RET
+
// See entry_amd64.go.
-TEXT ·sysret(SB),NOSPLIT,$0-24
+TEXT ·sysret(SB),NOSPLIT,$0-32
// Set application FS. We can't do this in Go because Go code needs FS.
MOVQ regs+8(FP), AX
MOVQ PTRACE_FS_BASE(AX), AX
@@ -182,9 +277,11 @@ TEXT ·sysret(SB),NOSPLIT,$0-24
POPQ AX // Restore AX.
POPQ SP // Restore SP.
SYSRET64()
+ // sysenter or exception will write our return value and return to our
+ // caller.
// See entry_amd64.go.
-TEXT ·iret(SB),NOSPLIT,$0-24
+TEXT ·iret(SB),NOSPLIT,$0-32
// Set application FS. We can't do this in Go because Go code needs FS.
MOVQ regs+8(FP), AX
MOVQ PTRACE_FS_BASE(AX), AX
@@ -220,6 +317,8 @@ TEXT ·iret(SB),NOSPLIT,$0-24
WRITE_CR3() // Switch to userCR3.
POPQ AX // Restore AX.
IRET()
+ // sysenter or exception will write our return value and return to our
+ // caller.
// See entry_amd64.go.
TEXT ·resume(SB),NOSPLIT,$0
@@ -324,11 +423,39 @@ kernel:
MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code.
MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
+ // Save floating point state. CPU.floatingPointState is a slice, so the
+ // first word of CPU.floatingPointState is a pointer to the destination
+ // array.
+ MOVQ CPU_FPU_STATE(AX), DI
+ MOVB CPU_HAS_XSAVE(AX), BX
+ MOVB CPU_HAS_XSAVEOPT(AX), CX
+ TESTB BX, BX
+ JZ no_xsave
+ // Use xsave/xsaveopt to save all extended state.
+ // We save everything unconditionally by setting RFBM to all 1's.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ TESTB CX, CX
+ JZ no_xsaveopt
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI)
+ JMP fpsave_done
+no_xsaveopt:
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI)
+ JMP fpsave_done
+no_xsave:
+ FXSAVE64 0(DI)
+fpsave_done:
+
// Call the syscall trampoline.
LOAD_KERNEL_STACK(GS)
- PUSHQ AX // First argument (vCPU).
- CALL ·kernelSyscall(SB) // Call the trampoline.
- POPQ AX // Pop vCPU.
+ MOVQ ENTRY_CPU_SELF(GS), AX // AX contains the vCPU.
+ PUSHQ AX // First argument (vCPU).
+ CALL ·kernelSyscall(SB) // Call the trampoline.
+ POPQ AX // Pop vCPU.
+
+ // We only trigger a bluepill entry in the bluepill function, and can
+ // therefore be guaranteed that there is no floating point state to be
+ // loaded on resuming from halt.
JMP ·resume(SB)
ADDR_OF_FUNC(·addrOfSysenter(SB), ·sysenter(SB));
@@ -416,15 +543,43 @@ kernel:
MOVQ 8(SP), BX // Load the error code.
MOVQ BX, CPU_ERROR_CODE(AX) // Copy out to the CPU.
MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
- MOVQ 0(SP), BX // BX contains the vector.
+
+ // Save floating point state. CPU.floatingPointState is a slice, so the
+ // first word of CPU.floatingPointState is a pointer to the destination
+ // array.
+ MOVQ CPU_FPU_STATE(AX), DI
+ MOVB CPU_HAS_XSAVE(AX), BX
+ MOVB CPU_HAS_XSAVEOPT(AX), CX
+ TESTB BX, BX
+ JZ no_xsave
+ // Use xsave/xsaveopt to save all extended state.
+ // We save everything unconditionally by setting RFBM to all 1's.
+ MOVL $0xffffffff, AX
+ MOVL $0xffffffff, DX
+ TESTB CX, CX
+ JZ no_xsaveopt
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x37; // XSAVEOPT64 0(DI)
+ JMP fpsave_done
+no_xsaveopt:
+ BYTE $0x48; BYTE $0x0f; BYTE $0xae; BYTE $0x27; // XSAVE64 0(DI)
+ JMP fpsave_done
+no_xsave:
+ FXSAVE64 0(DI)
+fpsave_done:
// Call the exception trampoline.
+ MOVQ 0(SP), BX // BX contains the vector.
LOAD_KERNEL_STACK(GS)
- PUSHQ BX // Second argument (vector).
- PUSHQ AX // First argument (vCPU).
- CALL ·kernelException(SB) // Call the trampoline.
- POPQ BX // Pop vector.
- POPQ AX // Pop vCPU.
+ MOVQ ENTRY_CPU_SELF(GS), AX // AX contains the vCPU.
+ PUSHQ BX // Second argument (vector).
+ PUSHQ AX // First argument (vCPU).
+ CALL ·kernelException(SB) // Call the trampoline.
+ POPQ BX // Pop vector.
+ POPQ AX // Pop vCPU.
+
+ // We only trigger a bluepill entry in the bluepill function, and can
+ // therefore be guaranteed that there is no floating point state to be
+ // loaded on resuming from halt.
JMP ·resume(SB)
#define EXCEPTION_WITH_ERROR(value, symbol, addr) \
diff --git a/pkg/ring0/kernel.go b/pkg/ring0/kernel.go
index 292f9d0cc..e7dd84929 100644
--- a/pkg/ring0/kernel.go
+++ b/pkg/ring0/kernel.go
@@ -14,6 +14,10 @@
package ring0
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
+)
+
// Init initializes a new kernel.
//
//go:nosplit
@@ -80,6 +84,7 @@ func (c *CPU) Init(k *Kernel, cpuID int, hooks Hooks) {
c.self = c // Set self reference.
c.kernel = k // Set kernel reference.
c.init(cpuID) // Perform architectural init.
+ c.floatingPointState = fpu.NewState()
// Require hooks.
if hooks != nil {
diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go
index 4a4c0ae26..7e55011b5 100644
--- a/pkg/ring0/kernel_amd64.go
+++ b/pkg/ring0/kernel_amd64.go
@@ -143,6 +143,9 @@ func (c *CPU) init(cpuID int) {
// Set mandatory flags.
c.registers.Eflags = KernelFlagsSet
+
+ c.hasXSAVE = hasXSAVE
+ c.hasXSAVEOPT = hasXSAVEOPT
}
// StackTop returns the kernel's stack address.
@@ -248,19 +251,21 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Ss = uint64(Udata) // Ditto.
// Perform the switch.
- swapgs() // GS will be swapped on return.
- WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
- LoadFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy in floating point.
+ needIRET := uint64(0)
if switchOpts.FullRestore {
- vector = iret(c, regs, uintptr(userCR3))
- } else {
- vector = sysret(c, regs, uintptr(userCR3))
+ needIRET = 1
}
- SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point.
- RestoreKernelFPState() // escapes: no. Restore kernel MXCSR.
+ vector = doSwitchToUser(c, regs, switchOpts.FloatingPointState.BytePointer(), userCR3, needIRET) // escapes: no.
return
}
+func doSwitchToUser(
+ cpu *CPU, // +0(FP)
+ regs *arch.Registers, // +8(FP)
+ fpState *byte, // +16(FP)
+ userCR3 uint64, // +24(FP)
+ needIRET uint64) Vector // +32(FP), +40(FP)
+
var (
sentryXCR0 uintptr
sentryXCR0Once sync.Once
@@ -287,7 +292,7 @@ func initSentryXCR0() {
//go:nosplit
func startGo(c *CPU) {
// Save per-cpu.
- WriteGS(kernelAddr(c.kernelEntry))
+ writeGS(kernelAddr(c.kernelEntry))
//
// TODO(mpratt): Note that per the note above, this should be done
diff --git a/pkg/ring0/lib_amd64.go b/pkg/ring0/lib_amd64.go
index 05c394ff5..c42a5b205 100644
--- a/pkg/ring0/lib_amd64.go
+++ b/pkg/ring0/lib_amd64.go
@@ -21,29 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
)
-// LoadFloatingPoint loads floating point state by the most efficient mechanism
-// available (set by Init).
-var LoadFloatingPoint func(*byte)
-
-// SaveFloatingPoint saves floating point state by the most efficient mechanism
-// available (set by Init).
-var SaveFloatingPoint func(*byte)
-
-// fxrstor uses fxrstor64 to load floating point state.
-func fxrstor(*byte)
-
-// xrstor uses xrstor to load floating point state.
-func xrstor(*byte)
-
-// fxsave uses fxsave64 to save floating point state.
-func fxsave(*byte)
-
-// xsave uses xsave to save floating point state.
-func xsave(*byte)
-
-// xsaveopt uses xsaveopt to save floating point state.
-func xsaveopt(*byte)
-
// writeFS sets the FS base address (selects one of wrfsbase or wrfsmsr).
func writeFS(addr uintptr)
@@ -53,8 +30,8 @@ func wrfsbase(addr uintptr)
// wrfsmsr writes to the GS_BASE MSR.
func wrfsmsr(addr uintptr)
-// WriteGS sets the GS address (set by init).
-var WriteGS func(addr uintptr)
+// writeGS sets the GS address (selects one of wrgsbase or wrgsmsr).
+func writeGS(addr uintptr)
// wrgsbase writes to the GS base address.
func wrgsbase(addr uintptr)
@@ -106,19 +83,4 @@ func Init(featureSet *cpuid.FeatureSet) {
hasXSAVE = featureSet.UseXsave()
hasFSGSBASE = featureSet.HasFeature(cpuid.X86FeatureFSGSBase)
validXCR0Mask = uintptr(featureSet.ValidXCR0Mask())
- if hasXSAVEOPT {
- SaveFloatingPoint = xsaveopt
- LoadFloatingPoint = xrstor
- } else if hasXSAVE {
- SaveFloatingPoint = xsave
- LoadFloatingPoint = xrstor
- } else {
- SaveFloatingPoint = fxsave
- LoadFloatingPoint = fxrstor
- }
- if hasFSGSBASE {
- WriteGS = wrgsbase
- } else {
- WriteGS = wrgsmsr
- }
}
diff --git a/pkg/ring0/lib_amd64.s b/pkg/ring0/lib_amd64.s
index 8ed98fc84..0f283aaae 100644
--- a/pkg/ring0/lib_amd64.s
+++ b/pkg/ring0/lib_amd64.s
@@ -128,6 +128,29 @@ TEXT ·wrfsmsr(SB),NOSPLIT,$0-8
BYTE $0x0f; BYTE $0x30;
RET
+// writeGS writes to the GS base.
+//
+// This is written in assembly because it must be callable from assembly (ABI0)
+// without an intermediate transition to ABIInternal.
+//
+// Preconditions: must be running in the lower address space, as it accesses
+// global data.
+TEXT ·writeGS(SB),NOSPLIT,$8-8
+ MOVQ addr+0(FP), AX
+
+ CMPB ·hasFSGSBASE(SB), $1
+ JNE msr
+
+ PUSHQ AX
+ CALL ·wrgsbase(SB)
+ POPQ AX
+ RET
+msr:
+ PUSHQ AX
+ CALL ·wrgsmsr(SB)
+ POPQ AX
+ RET
+
// wrgsbase writes to the GS base.
//
// The code corresponds to:
diff --git a/pkg/ring0/offsets_amd64.go b/pkg/ring0/offsets_amd64.go
index 75f6218b3..38fe27c35 100644
--- a/pkg/ring0/offsets_amd64.go
+++ b/pkg/ring0/offsets_amd64.go
@@ -35,6 +35,9 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_HAS_XSAVE 0x%02x\n", reflect.ValueOf(&c.hasXSAVE).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_HAS_XSAVEOPT 0x%02x\n", reflect.ValueOf(&c.hasXSAVEOPT).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_FPU_STATE 0x%02x\n", reflect.ValueOf(&c.floatingPointState).Pointer()-reflect.ValueOf(c).Pointer())
e := &kernelEntry{}
fmt.Fprintf(w, "\n// CPU entry offsets.\n")
diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go
index 9dac53c80..3f17fba49 100644
--- a/pkg/ring0/pagetables/pagetables.go
+++ b/pkg/ring0/pagetables/pagetables.go
@@ -322,12 +322,3 @@ func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarc
func (p *PageTables) MarkReadOnlyShared() {
p.readOnlyShared = true
}
-
-// PrefaultRootTable touches the root table page to be sure that its physical
-// pages are mapped.
-//
-//go:nosplit
-//go:noinline
-func (p *PageTables) PrefaultRootTable() PTE {
- return p.root[0]
-}
diff --git a/pkg/safecopy/BUILD b/pkg/safecopy/BUILD
index 0a045fc8e..2a1602e2b 100644
--- a/pkg/safecopy/BUILD
+++ b/pkg/safecopy/BUILD
@@ -18,9 +18,9 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
- "//pkg/abi/linux",
"//pkg/errors",
"//pkg/errors/linuxerr",
+ "//pkg/sighandling",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go
index a9711e63d..0dd0aea83 100644
--- a/pkg/safecopy/safecopy.go
+++ b/pkg/safecopy/safecopy.go
@@ -23,6 +23,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/errors"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/sighandling"
)
// SegvError is returned when a safecopy function receives SIGSEGV.
@@ -132,10 +133,10 @@ func initializeAddresses() {
func init() {
initializeAddresses()
- if err := ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil {
+ if err := sighandling.ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err))
}
- if err := ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil {
+ if err := sighandling.ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err))
}
linuxerr.AddErrorUnwrapper(func(e error) (*errors.Error, bool) {
diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go
index 2365b2c0d..15f84abea 100644
--- a/pkg/safecopy/safecopy_unsafe.go
+++ b/pkg/safecopy/safecopy_unsafe.go
@@ -20,7 +20,6 @@ import (
"unsafe"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/abi/linux"
)
// maxRegisterSize is the maximum register size used in memcpy and memclr. It
@@ -332,39 +331,3 @@ func errorFromFaultSignal(addr uintptr, sig int32) error {
panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr))
}
}
-
-// ReplaceSignalHandler replaces the existing signal handler for the provided
-// signal with the one that handles faults in safecopy-protected functions.
-//
-// It stores the value of the previously set handler in previous.
-//
-// This function will be called on initialization in order to install safecopy
-// handlers for appropriate signals. These handlers will call the previous
-// handler however, and if this is function is being used externally then the
-// same courtesy is expected.
-func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error {
- var sa linux.SigAction
- const maskLen = 8
-
- // Get the existing signal handler information, and save the current
- // handler. Once we replace it, we will use this pointer to fall back to
- // it when we receive other signals.
- if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 {
- return e
- }
-
- // Fail if there isn't a previous handler.
- if sa.Handler == 0 {
- return fmt.Errorf("previous handler for signal %x isn't set", sig)
- }
-
- *previous = uintptr(sa.Handler)
-
- // Install our own handler.
- sa.Handler = uint64(handler)
- if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 {
- return e
- }
-
- return nil
-}
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index 2f3664c57..f721b7236 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -26,6 +26,23 @@ import (
"gvisor.dev/gvisor/pkg/urpc"
)
+const (
+ // DefaultBlockProfileRate is the default profiling rate for block
+ // profiles.
+ //
+ // The default here is 10%, which will record a stacktrace 10% of the
+ // time when blocking occurs. Since these events should not be super
+ // frequent, we expect this to achieve a reasonable balance between
+ // collecting the data we need and imposing a high performance cost
+ // (e.g. skewing even the CPU profile).
+ DefaultBlockProfileRate = 10
+
+ // DefaultMutexProfileRate is the default profiling rate for mutex
+ // profiles. Like the block rate above, we use a default rate of 10%
+ // for the same reasons.
+ DefaultMutexProfileRate = 10
+)
+
// Profile includes profile-related RPC stubs. It provides a way to
// control the built-in runtime profiling facilities.
//
@@ -175,12 +192,8 @@ func (p *Profile) Block(o *BlockProfileOpts, _ *struct{}) error {
defer p.blockMu.Unlock()
// Always set the rate. We then wait to collect a profile at this rate,
- // and disable when we're done. Note that the default here is 10%, which
- // will record a stacktrace 10% of the time when blocking occurs. Since
- // these events should not be super frequent, we expect this to achieve
- // a reasonable balance between collecting the data we need and imposing
- // a high performance cost (e.g. skewing even the CPU profile).
- rate := 10
+ // and disable when we're done.
+ rate := DefaultBlockProfileRate
if o.Rate != 0 {
rate = o.Rate
}
@@ -220,9 +233,8 @@ func (p *Profile) Mutex(o *MutexProfileOpts, _ *struct{}) error {
p.mutexMu.Lock()
defer p.mutexMu.Unlock()
- // Always set the fraction. Like the block rate above, we use
- // a default rate of 10% for the same reasons.
- fraction := 10
+ // Always set the fraction.
+ fraction := DefaultMutexProfileRate
if o.Fraction != 0 {
fraction = o.Fraction
}
diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go
index 4370cce33..d2eb03bb7 100644
--- a/pkg/sentry/fs/fdpipe/pipe.go
+++ b/pkg/sentry/fs/fdpipe/pipe.go
@@ -45,7 +45,8 @@ type pipeOperations struct {
fsutil.FileNoIoctl `state:"nosave"`
fsutil.FileNoSplice `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- waiter.Queue `state:"nosave"`
+
+ waiter.Queue
// flags are the flags used to open the pipe.
flags fs.FileFlags `state:".(fs.FileFlags)"`
diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go
index 031cd33ce..a27dd0b9a 100644
--- a/pkg/sentry/fs/file_overlay.go
+++ b/pkg/sentry/fs/file_overlay.go
@@ -16,6 +16,7 @@ package fs
import (
"io"
+ "math"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
@@ -360,10 +361,13 @@ func (*overlayFileOperations) ConfigureMMap(ctx context.Context, file *File, opt
return linuxerr.ENODEV
}
- // FIXME(jamieliu): This is a copy/paste of fsutil.GenericConfigureMMap,
- // which we can't use because the overlay implementation is in package fs,
- // so depending on fs/fsutil would create a circular dependency. Move
- // overlay to fs/overlay.
+ // TODO(gvisor.dev/issue/1624): This is a copy/paste of
+ // fsutil.GenericConfigureMMap, which we can't use because the overlay
+ // implementation is in package fs, so depending on fs/fsutil would create
+ // a circular dependency. VFS2 overlay doesn't have this issue.
+ if opts.Offset+opts.Length > math.MaxInt64 {
+ return linuxerr.EOVERFLOW
+ }
opts.Mappable = o
opts.MappingIdentity = file
file.IncRef()
diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go
index 3ece73b81..38e3ed42d 100644
--- a/pkg/sentry/fs/fsutil/file.go
+++ b/pkg/sentry/fs/fsutil/file.go
@@ -16,6 +16,7 @@ package fsutil
import (
"io"
+ "math"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
@@ -210,6 +211,9 @@ func (FileNoMMap) ConfigureMMap(context.Context, *fs.File, *memmap.MMapOpts) err
// GenericConfigureMMap implements fs.FileOperations.ConfigureMMap for most
// filesystems that support memory mapping.
func GenericConfigureMMap(file *fs.File, m memmap.Mappable, opts *memmap.MMapOpts) error {
+ if opts.Offset+opts.Length > math.MaxInt64 {
+ return linuxerr.EOVERFLOW
+ }
opts.Mappable = m
opts.MappingIdentity = file
file.IncRef()
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index 23528bf25..37ddb1a3c 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -93,7 +93,8 @@ func NewHostFileMapper() *HostFileMapper {
func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) {
f.refsMu.Lock()
defer f.refsMu.Unlock()
- for chunkStart := mr.Start &^ chunkMask; chunkStart < mr.End; chunkStart += chunkSize {
+ chunkStart := mr.Start &^ chunkMask
+ for {
refs := f.refs[chunkStart]
pgs := pagesInChunk(mr, chunkStart)
if refs+pgs < refs {
@@ -101,6 +102,10 @@ func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) {
panic(fmt.Sprintf("HostFileMapper.IncRefOn(%v): adding %d page references to chunk %#x, which has %d page references", mr, pgs, chunkStart, refs))
}
f.refs[chunkStart] = refs + pgs
+ chunkStart += chunkSize
+ if chunkStart >= mr.End || chunkStart == 0 {
+ break
+ }
}
}
@@ -112,7 +117,8 @@ func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) {
func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) {
f.refsMu.Lock()
defer f.refsMu.Unlock()
- for chunkStart := mr.Start &^ chunkMask; chunkStart < mr.End; chunkStart += chunkSize {
+ chunkStart := mr.Start &^ chunkMask
+ for {
refs := f.refs[chunkStart]
pgs := pagesInChunk(mr, chunkStart)
switch {
@@ -128,6 +134,10 @@ func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) {
case refs < pgs:
panic(fmt.Sprintf("HostFileMapper.DecRefOn(%v): removing %d page references from chunk %#x, which has %d page references", mr, pgs, chunkStart, refs))
}
+ chunkStart += chunkSize
+ if chunkStart >= mr.End || chunkStart == 0 {
+ break
+ }
}
}
@@ -161,7 +171,8 @@ func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int,
if write {
prot |= unix.PROT_WRITE
}
- for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize {
+ chunkStart := fr.Start &^ chunkMask
+ for {
m, ok := f.mappings[chunkStart]
if !ok {
addr, _, errno := unix.Syscall6(
@@ -201,6 +212,10 @@ func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int,
endOff = fr.End - chunkStart
}
fn(f.unsafeBlockFromChunkMapping(m.addr).TakeFirst64(endOff).DropFirst64(startOff))
+ chunkStart += chunkSize
+ if chunkStart >= fr.End || chunkStart == 0 {
+ break
+ }
}
return nil
}
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 92d58e3e9..99c37291e 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -70,7 +70,7 @@ type inodeFileState struct {
descriptor *descriptor `state:"wait"`
// Event queue for blocking operations.
- queue waiter.Queue `state:"zerovalue"`
+ queue waiter.Queue
// sattr is used to restore the inodeOperations.
sattr fs.StableAttr `state:"wait"`
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index 51cd6cd37..941f37116 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -43,7 +43,7 @@ type Inotify struct {
// user, since we may aggressively reuse an id on S/R.
id uint64
- waiter.Queue `state:"nosave"`
+ waiter.Queue
// evMu *only* protects the events list. We need a separate lock because
// while queuing events, a watch needs to lock the event queue, and using mu
diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go
index 7d7a207cc..e39d340fe 100644
--- a/pkg/sentry/fs/lock/lock.go
+++ b/pkg/sentry/fs/lock/lock.go
@@ -132,7 +132,7 @@ type Locks struct {
locks LockSet
// blockedQueue is the queue of waiters that are waiting on a lock.
- blockedQueue waiter.Queue `state:"zerovalue"`
+ blockedQueue waiter.Queue
}
// Blocker is the interface used for blocking locks. Passing a nil Blocker
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
index 085aa6d61..443b9a94c 100644
--- a/pkg/sentry/fs/proc/sys.go
+++ b/pkg/sentry/fs/proc/sys.go
@@ -109,6 +109,9 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode
"shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))),
"shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))),
"shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))),
+ "msgmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMNI, 10))),
+ "msgmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMAX, 10))),
+ "msgmnb": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.MSGMNB, 10))),
}
d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555))
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
index 1c8518d71..ca8be8683 100644
--- a/pkg/sentry/fs/timerfd/timerfd.go
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -43,7 +43,7 @@ type TimerOperations struct {
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- events waiter.Queue `state:"zerovalue"`
+ events waiter.Queue
timer *ktime.Timer
// val is the number of timer expirations since the last successful call to
diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go
index f9fca6d8e..f2c9e9668 100644
--- a/pkg/sentry/fs/tty/line_discipline.go
+++ b/pkg/sentry/fs/tty/line_discipline.go
@@ -102,10 +102,10 @@ type lineDiscipline struct {
column int
// masterWaiter is used to wait on the master end of the TTY.
- masterWaiter waiter.Queue `state:"zerovalue"`
+ masterWaiter waiter.Queue
// replicaWaiter is used to wait on the replica end of the TTY.
- replicaWaiter waiter.Queue `state:"zerovalue"`
+ replicaWaiter waiter.Queue
}
func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD
index e5fdcc776..60ee5ede2 100644
--- a/pkg/sentry/fsimpl/cgroupfs/BUILD
+++ b/pkg/sentry/fsimpl/cgroupfs/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
licenses(["notice"])
@@ -18,6 +18,7 @@ go_library(
name = "cgroupfs",
srcs = [
"base.go",
+ "bitmap.go",
"cgroupfs.go",
"cpu.go",
"cpuacct.go",
@@ -29,10 +30,12 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/bitmap",
"//pkg/context",
"//pkg/coverage",
"//pkg/errors/linuxerr",
"//pkg/fspath",
+ "//pkg/hostarch",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
@@ -47,3 +50,11 @@ go_library(
"//pkg/usermem",
],
)
+
+go_test(
+ name = "cgroupfs_test",
+ size = "small",
+ srcs = ["bitmap_test.go"],
+ library = ":cgroupfs",
+ deps = ["//pkg/bitmap"],
+)
diff --git a/pkg/sentry/fsimpl/cgroupfs/bitmap.go b/pkg/sentry/fsimpl/cgroupfs/bitmap.go
new file mode 100644
index 000000000..8074641db
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/bitmap.go
@@ -0,0 +1,139 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/bitmap"
+)
+
+// formatBitmap produces a string representation of b, which lists the indices
+// of set bits in the bitmap. Indices are separated by commas and ranges of
+// set bits are abbreviated. Example outputs: "0,2,4", "0,3-7,10", "0-10".
+//
+// Inverse of parseBitmap.
+func formatBitmap(b *bitmap.Bitmap) string {
+ ones := b.ToSlice()
+ if len(ones) == 0 {
+ return ""
+ }
+
+ elems := make([]string, 0, len(ones))
+ runStart := ones[0]
+ lastVal := ones[0]
+ inRun := false
+
+ for _, v := range ones[1:] {
+ last := lastVal
+ lastVal = v
+
+ if last+1 == v {
+ // In a contiguous block of ones.
+ if !inRun {
+ runStart = last
+ inRun = true
+ }
+
+ continue
+ }
+
+ // Non-contiguous bit.
+ if inRun {
+ // Render a run
+ elems = append(elems, fmt.Sprintf("%d-%d", runStart, last))
+ inRun = false
+ continue
+ }
+
+ // Lone non-contiguous bit.
+ elems = append(elems, fmt.Sprintf("%d", last))
+
+ }
+
+ // Process potential final run
+ if inRun {
+ elems = append(elems, fmt.Sprintf("%d-%d", runStart, lastVal))
+ } else {
+ elems = append(elems, fmt.Sprintf("%d", lastVal))
+ }
+
+ return strings.Join(elems, ",")
+}
+
+func parseToken(token string) (start, end uint32, err error) {
+ ts := strings.SplitN(token, "-", 2)
+ switch len(ts) {
+ case 0:
+ return 0, 0, fmt.Errorf("invalid token %q", token)
+ case 1:
+ val, err := strconv.ParseUint(ts[0], 10, 32)
+ if err != nil {
+ return 0, 0, err
+ }
+ return uint32(val), uint32(val), nil
+ case 2:
+ val1, err := strconv.ParseUint(ts[0], 10, 32)
+ if err != nil {
+ return 0, 0, err
+ }
+ val2, err := strconv.ParseUint(ts[1], 10, 32)
+ if err != nil {
+ return 0, 0, err
+ }
+ if val1 >= val2 {
+ return 0, 0, fmt.Errorf("start (%v) must be less than end (%v)", val1, val2)
+ }
+ return uint32(val1), uint32(val2), nil
+ default:
+ panic(fmt.Sprintf("Unreachable: got %d substrs", len(ts)))
+ }
+}
+
+// parseBitmap parses input as a bitmap. input should be a comma separated list
+// of indices, and ranges of set bits may be abbreviated. Examples: "0,2,4",
+// "0,3-7,10", "0-10". Input after the first newline or null byte is discarded.
+//
+// sizeHint sets the initial size of the bitmap, which may prevent reallocation
+// when growing the bitmap during parsing. Ideally sizeHint should be at least
+// as large as the bitmap represented by input, but this is not required.
+//
+// Inverse of formatBitmap.
+func parseBitmap(input string, sizeHint uint32) (*bitmap.Bitmap, error) {
+ b := bitmap.New(sizeHint)
+
+ if termIdx := strings.IndexAny(input, "\n\000"); termIdx != -1 {
+ input = input[:termIdx]
+ }
+ input = strings.TrimSpace(input)
+
+ if len(input) == 0 {
+ return &b, nil
+ }
+ tokens := strings.Split(input, ",")
+
+ for _, t := range tokens {
+ start, end, err := parseToken(strings.TrimSpace(t))
+ if err != nil {
+ return nil, err
+ }
+ for i := start; i <= end; i++ {
+ b.Add(i)
+ }
+ }
+ return &b, nil
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/bitmap_test.go b/pkg/sentry/fsimpl/cgroupfs/bitmap_test.go
new file mode 100644
index 000000000..5cc56de3b
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/bitmap_test.go
@@ -0,0 +1,99 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/bitmap"
+)
+
+// TestFormat checks formatBitmap against known cases, and additionally
+// verifies that parseBitmap round-trips the formatted string back to the
+// original bitmap.
+func TestFormat(t *testing.T) {
+ tests := []struct {
+ input []uint32
+ output string
+ }{
+ {[]uint32{1, 2, 3, 4, 7}, "1-4,7"},
+ {[]uint32{2}, "2"},
+ {[]uint32{0, 1, 2}, "0-2"},
+ {[]uint32{}, ""},
+ {[]uint32{1, 3, 4, 5, 6, 9, 11, 13, 14, 15, 16, 17}, "1,3-6,9,11,13-17"},
+ {[]uint32{2, 3, 10, 12, 13, 14, 15, 16, 20, 21, 33, 34, 47}, "2-3,10,12-16,20-21,33-34,47"},
+ }
+ for i, tt := range tests {
+ t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) {
+ b := bitmap.New(64)
+ for _, v := range tt.input {
+ b.Add(v)
+ }
+ s := formatBitmap(&b)
+ if s != tt.output {
+ t.Errorf("Expected %q, got %q", tt.output, s)
+ }
+ // Round-trip: parsing the formatted output must reproduce b.
+ b1, err := parseBitmap(s, 64)
+ if err != nil {
+ t.Fatalf("Failed to parse formatted bitmap: %v", err)
+ }
+ if got, want := b1.ToSlice(), b.ToSlice(); !reflect.DeepEqual(got, want) {
+ t.Errorf("Parsing formatted output doesn't result in the original bitmap. Got %v, want %v", got, want)
+ }
+ })
+ }
+}
+
+// TestParse checks parseBitmap against valid inputs (single indices, range
+// abbreviations, and inputs truncated at a newline or null byte) and invalid
+// ones (non-numeric tokens, inverted ranges, and empty ranges like "3-3").
+func TestParse(t *testing.T) {
+ tests := []struct {
+ input string
+ output []uint32
+ shouldFail bool
+ }{
+ {"1", []uint32{1}, false},
+ {"", []uint32{}, false},
+ {"1,2,3,4", []uint32{1, 2, 3, 4}, false},
+ {"1-4", []uint32{1, 2, 3, 4}, false},
+ {"1,2-4", []uint32{1, 2, 3, 4}, false},
+ {"1,2-3,4", []uint32{1, 2, 3, 4}, false},
+ {"1-2,3,4,10,11", []uint32{1, 2, 3, 4, 10, 11}, false},
+ {"1,2-4,5,16", []uint32{1, 2, 3, 4, 5, 16}, false},
+ {"abc", []uint32{}, true},
+ {"1,3-2,4", []uint32{}, true},
+ {"1,3-3,4", []uint32{}, true},
+ {"1,2,3\000,4", []uint32{1, 2, 3}, false},
+ {"1,2,3\n,4", []uint32{1, 2, 3}, false},
+ }
+ for i, tt := range tests {
+ t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) {
+ b, err := parseBitmap(tt.input, 64)
+ if tt.shouldFail {
+ if err == nil {
+ t.Fatalf("Expected parsing of %q to fail, but it didn't", tt.input)
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("Failed to parse bitmap: %v", err)
+ return
+ }
+
+ got := b.ToSlice()
+ if !reflect.DeepEqual(got, tt.output) {
+ t.Errorf("Parsed bitmap doesn't match what we expected. Got %v, want %v", got, tt.output)
+ }
+
+ })
+ }
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
index edc3b50b9..e089b2c28 100644
--- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
+++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
@@ -269,7 +269,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
case controllerCPUAcct:
c = newCPUAcctController(fs)
case controllerCPUSet:
- c = newCPUSetController(fs)
+ c = newCPUSetController(k, fs)
case controllerJob:
c = newJobController(fs)
case controllerMemory:
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuset.go b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
index ac547f8e2..62e7029da 100644
--- a/pkg/sentry/fsimpl/cgroupfs/cpuset.go
+++ b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
@@ -15,25 +15,133 @@
package cgroupfs
import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/bitmap"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// +stateify savable
type cpusetController struct {
controllerCommon
+
+ maxCpus uint32
+ maxMems uint32
+
+ cpus *bitmap.Bitmap
+ mems *bitmap.Bitmap
}
var _ controller = (*cpusetController)(nil)
-func newCPUSetController(fs *filesystem) *cpusetController {
- c := &cpusetController{}
+// newCPUSetController creates a cpuset controller whose cpus bitmap has every
+// application core set and whose mems bitmap has the single reported NUMA
+// node set.
+func newCPUSetController(k *kernel.Kernel, fs *filesystem) *cpusetController {
+ cores := uint32(k.ApplicationCores())
+ // FlipRange on a fresh bitmap sets all bits in [0, cores).
+ cpus := bitmap.New(cores)
+ cpus.FlipRange(0, cores)
+ mems := bitmap.New(1)
+ mems.FlipRange(0, 1)
+ c := &cpusetController{
+ cpus: &cpus,
+ mems: &mems,
+ maxCpus: uint32(k.ApplicationCores()), // Same value as cores above.
+ maxMems: 1, // We always report a single NUMA node.
+ }
 c.controllerCommon.init(controllerCPUSet, fs)
 return c
}
// AddControlFiles implements controller.AddControlFiles.
func (c *cpusetController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
- // This controller is currently intentionally empty.
+ // Expose the cpus and mems bitmaps as writable control files backed by
+ // cpusData and memsData respectively.
+ contents["cpuset.cpus"] = c.fs.newControllerWritableFile(ctx, creds, &cpusData{c: c})
+ contents["cpuset.mems"] = c.fs.newControllerWritableFile(ctx, creds, &memsData{c: c})
+}
+
+// cpusData backs the "cpuset.cpus" control file.
+//
+// +stateify savable
+type cpusData struct {
+ c *cpusetController
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // Emit the bitmap in cpulist format ("0-3,5"), newline-terminated.
+ fmt.Fprintf(buf, "%s\n", formatBitmap(d.c.cpus))
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *cpusData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ src = src.DropFirst64(offset)
+ // Cap writes at one page.
+ if src.NumBytes() > hostarch.PageSize {
+ return 0, linuxerr.EINVAL
+ }
+
+ // NOTE(review): assumes ctx always carries a task; a nil t would panic in
+ // CopyScratchBuffer — confirm writes can only originate from task context.
+ t := kernel.TaskFromContext(ctx)
+ buf := t.CopyScratchBuffer(hostarch.PageSize)
+ n, err := src.CopyIn(ctx, buf)
+ if err != nil {
+ return 0, err
+ }
+ buf = buf[:n]
+
+ b, err := parseBitmap(string(buf), d.c.maxCpus)
+ if err != nil {
+ log.Warningf("cgroupfs cpuset controller: Failed to parse bitmap: %v", err)
+ return 0, linuxerr.EINVAL
+ }
+
+ // NOTE(review): if Maximum() returns the highest set bit index, this check
+ // admits index == maxCpus even though valid CPUs are [0, maxCpus) — looks
+ // like an off-by-one (`>=` intended?); confirm Maximum()'s semantics.
+ if got, want := b.Maximum(), d.c.maxCpus; got > want {
+ log.Warningf("cgroupfs cpuset controller: Attempted to specify cpuset.cpus beyond highest available cpu: got %d, want %d", got, want)
+ return 0, linuxerr.EINVAL
+ }
+
+ // NOTE(review): this pointer swap is not synchronized with concurrent
+ // readers/writers of d.c.cpus — confirm external locking exists.
+ d.c.cpus = b
+ return int64(n), nil
+}
+
+// memsData backs the "cpuset.mems" control file.
+//
+// +stateify savable
+type memsData struct {
+ c *cpusetController
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *memsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // Emit the bitmap in memlist format ("0"), newline-terminated.
+ fmt.Fprintf(buf, "%s\n", formatBitmap(d.c.mems))
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *memsData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ src = src.DropFirst64(offset)
+ // Cap writes at one page.
+ if src.NumBytes() > hostarch.PageSize {
+ return 0, linuxerr.EINVAL
+ }
+
+ // NOTE(review): assumes ctx always carries a task; a nil t would panic in
+ // CopyScratchBuffer — confirm writes can only originate from task context.
+ t := kernel.TaskFromContext(ctx)
+ buf := t.CopyScratchBuffer(hostarch.PageSize)
+ n, err := src.CopyIn(ctx, buf)
+ if err != nil {
+ return 0, err
+ }
+ buf = buf[:n]
+
+ b, err := parseBitmap(string(buf), d.c.maxMems)
+ if err != nil {
+ log.Warningf("cgroupfs cpuset controller: Failed to parse bitmap: %v", err)
+ return 0, linuxerr.EINVAL
+ }
+
+ // NOTE(review): same suspected off-by-one as cpusData.Write — with
+ // maxMems == 1, a bitmap with only bit 1 set (Maximum() == 1) passes this
+ // check although the only valid node is 0; confirm Maximum()'s semantics.
+ if got, want := b.Maximum(), d.c.maxMems; got > want {
+ log.Warningf("cgroupfs cpuset controller: Attempted to specify cpuset.mems beyond highest available node: got %d, want %d", got, want)
+ return 0, linuxerr.EINVAL
+ }
+
+ // NOTE(review): unsynchronized pointer swap, as in cpusData.Write —
+ // confirm external locking exists.
+ d.c.mems = b
+ return int64(n), nil
}
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 4244f2cf5..509dd0e1a 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -54,7 +54,10 @@ go_library(
"//pkg/fdnotifier",
"//pkg/fspath",
"//pkg/hostarch",
+ "//pkg/lisafs",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/p9",
"//pkg/refs",
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 5c48a9fee..d99a6112c 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -222,47 +222,88 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
off := uint64(0)
const count = 64 * 1024 // for consistency with the vfs1 client
d.handleMu.RLock()
- if d.readFile.isNil() {
+ if !d.isReadFileOk() {
// This should not be possible because a readable handle should
// have been opened when the calling directoryFD was opened.
d.handleMu.RUnlock()
panic("gofer.dentry.getDirents called without a readable handle")
}
+ // shouldSeek0 indicates whether the server should SEEK to 0 before reading
+ // directory entries.
+ shouldSeek0 := true
for {
- p9ds, err := d.readFile.readdir(ctx, off, count)
- if err != nil {
- d.handleMu.RUnlock()
- return nil, err
- }
- if len(p9ds) == 0 {
- d.handleMu.RUnlock()
- break
- }
- for _, p9d := range p9ds {
- if p9d.Name == "." || p9d.Name == ".." {
- continue
+ if d.fs.opts.lisaEnabled {
+ countLisa := int32(count)
+ if shouldSeek0 {
+ // See lisafs.Getdents64Req.Count.
+ countLisa = -countLisa
+ shouldSeek0 = false
+ }
+ lisafsDs, err := d.readFDLisa.Getdents64(ctx, countLisa)
+ if err != nil {
+ d.handleMu.RUnlock()
+ return nil, err
+ }
+ if len(lisafsDs) == 0 {
+ d.handleMu.RUnlock()
+ break
+ }
+ for i := range lisafsDs {
+ name := string(lisafsDs[i].Name)
+ if name == "." || name == ".." {
+ continue
+ }
+ dirent := vfs.Dirent{
+ Name: name,
+ Ino: d.fs.inoFromKey(inoKey{
+ ino: uint64(lisafsDs[i].Ino),
+ devMinor: uint32(lisafsDs[i].DevMinor),
+ devMajor: uint32(lisafsDs[i].DevMajor),
+ }),
+ NextOff: int64(len(dirents) + 1),
+ Type: uint8(lisafsDs[i].Type),
+ }
+ dirents = append(dirents, dirent)
+ if realChildren != nil {
+ realChildren[name] = struct{}{}
+ }
}
- dirent := vfs.Dirent{
- Name: p9d.Name,
- Ino: d.fs.inoFromQIDPath(p9d.QID.Path),
- NextOff: int64(len(dirents) + 1),
+ } else {
+ p9ds, err := d.readFile.readdir(ctx, off, count)
+ if err != nil {
+ d.handleMu.RUnlock()
+ return nil, err
}
- // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
- // DMSOCKET.
- switch p9d.Type {
- case p9.TypeSymlink:
- dirent.Type = linux.DT_LNK
- case p9.TypeDir:
- dirent.Type = linux.DT_DIR
- default:
- dirent.Type = linux.DT_REG
+ if len(p9ds) == 0 {
+ d.handleMu.RUnlock()
+ break
}
- dirents = append(dirents, dirent)
- if realChildren != nil {
- realChildren[p9d.Name] = struct{}{}
+ for _, p9d := range p9ds {
+ if p9d.Name == "." || p9d.Name == ".." {
+ continue
+ }
+ dirent := vfs.Dirent{
+ Name: p9d.Name,
+ Ino: d.fs.inoFromQIDPath(p9d.QID.Path),
+ NextOff: int64(len(dirents) + 1),
+ }
+ // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
+ // DMSOCKET.
+ switch p9d.Type {
+ case p9.TypeSymlink:
+ dirent.Type = linux.DT_LNK
+ case p9.TypeDir:
+ dirent.Type = linux.DT_DIR
+ default:
+ dirent.Type = linux.DT_REG
+ }
+ dirents = append(dirents, dirent)
+ if realChildren != nil {
+ realChildren[p9d.Name] = struct{}{}
+ }
}
+ off = p9ds[len(p9ds)-1].Offset
}
- off = p9ds[len(p9ds)-1].Offset
}
}
// Emit entries for synthetic children.
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 00228c469..23c8b8ce3 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -21,10 +21,12 @@ import (
"sync"
"sync/atomic"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
"gvisor.dev/gvisor/pkg/sentry/fsmetric"
@@ -53,9 +55,47 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// regardless.
var retErr error
+ if fs.opts.lisaEnabled {
+ // Try accumulating all FDIDs to fsync and fsync then via one RPC as
+ // opposed to making an RPC per FDID. Passing a non-nil accFsyncFDIDs to
+ // dentry.syncCachedFile() and specialFileFD.sync() will cause them to not
+ // make an RPC, instead accumulate syncable FDIDs in the passed slice.
+ accFsyncFDIDs := make([]lisafs.FDID, 0, len(ds)+len(sffds))
+
+ // Sync syncable dentries.
+ for _, d := range ds {
+ if err := d.syncCachedFile(ctx, true /* forFilesystemSync */, &accFsyncFDIDs); err != nil {
+ ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
+ }
+ }
+
+ // Sync special files, which may be writable but do not use dentry shared
+ // handles (so they won't be synced by the above).
+ for _, sffd := range sffds {
+ if err := sffd.sync(ctx, true /* forFilesystemSync */, &accFsyncFDIDs); err != nil {
+ ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
+ }
+ }
+
+ if err := fs.clientLisa.SyncFDs(ctx, accFsyncFDIDs); err != nil {
+ ctx.Infof("gofer.filesystem.Sync: fs.fsyncMultipleFDLisa failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
+ }
+
+ return retErr
+ }
+
// Sync syncable dentries.
for _, d := range ds {
- if err := d.syncCachedFile(ctx, true /* forFilesystemSync */); err != nil {
+ if err := d.syncCachedFile(ctx, true /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */); err != nil {
ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err)
if retErr == nil {
retErr = err
@@ -66,7 +106,7 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// Sync special files, which may be writable but do not use dentry shared
// handles (so they won't be synced by the above).
for _, sffd := range sffds {
- if err := sffd.sync(ctx, true /* forFilesystemSync */); err != nil {
+ if err := sffd.sync(ctx, true /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */); err != nil {
ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err)
if retErr == nil {
retErr = err
@@ -130,7 +170,7 @@ func putDentrySlice(ds *[]*dentry) {
// but dentry slices are allocated lazily, and it's much easier to say "defer
// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() {
// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this.
-// +checklocksrelease:fs.renameMu
+// +checklocksreleaseread:fs.renameMu
func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, dsp **[]*dentry) {
fs.renameMu.RUnlock()
if *dsp == nil {
@@ -197,7 +237,13 @@ afterSymlink:
rp.Advance()
return d.parent, followedSymlink, nil
}
- child, err := fs.getChildLocked(ctx, d, name, ds)
+ var child *dentry
+ var err error
+ if fs.opts.lisaEnabled {
+ child, err = fs.getChildAndWalkPathLocked(ctx, d, rp, ds)
+ } else {
+ child, err = fs.getChildLocked(ctx, d, name, ds)
+ }
if err != nil {
return nil, false, err
}
@@ -219,6 +265,99 @@ afterSymlink:
return child, followedSymlink, nil
}
+// getChildAndWalkPathLocked returns the dentry for the first path component
+// of rp, additionally walking and caching as many of the following components
+// as possible in a single WalkMultiple RPC.
+//
+// Preconditions:
+// * fs.opts.lisaEnabled.
+// * fs.renameMu must be locked.
+// * parent.dirMu must be locked.
+// * parent.isDir().
+// * parent and the dentry at name have been revalidated.
+func (fs *filesystem) getChildAndWalkPathLocked(ctx context.Context, parent *dentry, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
+ // Note that pit is a copy of the iterator that does not affect rp.
+ pit := rp.Pit()
+ first := pit.String()
+ if len(first) > maxFilenameLen {
+ return nil, linuxerr.ENAMETOOLONG
+ }
+ // Fast path: the first component is already cached (possibly as a
+ // negative entry), or the parent is synthetic and has no remote children.
+ if child, ok := parent.children[first]; ok || parent.isSynthetic() {
+ if child == nil {
+ return nil, linuxerr.ENOENT
+ }
+ return child, nil
+ }
+
+ // Walk as much of the path as possible in 1 RPC.
+ names := []string{first}
+ for pit = pit.Next(); pit.Ok(); pit = pit.Next() {
+ name := pit.String()
+ if name == "." {
+ continue
+ }
+ if name == ".." {
+ // Stop batching at "..": it would walk back up the tree.
+ break
+ }
+ names = append(names, name)
+ }
+ // inodes holds one entry per component successfully walked; status
+ // reports why the walk stopped early, if it did.
+ status, inodes, err := parent.controlFDLisa.WalkMultiple(ctx, names)
+ if err != nil {
+ return nil, err
+ }
+ if len(inodes) == 0 {
+ parent.cacheNegativeLookupLocked(first)
+ return nil, linuxerr.ENOENT
+ }
+
+ // Add the walked inodes into the dentry tree.
+ curParent := parent
+ // parent.dirMu is already held by the caller; these helpers lock dirMu
+ // only for the intermediate dentries created below.
+ curParentDirMuLock := func() {
+ if curParent != parent {
+ curParent.dirMu.Lock()
+ }
+ }
+ curParentDirMuUnlock := func() {
+ if curParent != parent {
+ curParent.dirMu.Unlock() // +checklocksforce: locked via curParentDirMuLock().
+ }
+ }
+ var ret *dentry
+ var dentryCreationErr error
+ for i := range inodes {
+ if dentryCreationErr != nil {
+ // A previous dentry failed; just release the remaining control FDs.
+ fs.clientLisa.CloseFDBatched(ctx, inodes[i].ControlFD)
+ continue
+ }
+
+ child, err := fs.newDentryLisa(ctx, &inodes[i])
+ if err != nil {
+ fs.clientLisa.CloseFDBatched(ctx, inodes[i].ControlFD)
+ dentryCreationErr = err
+ continue
+ }
+ curParentDirMuLock()
+ curParent.cacheNewChildLocked(child, names[i])
+ curParentDirMuUnlock()
+ // For now, child has 0 references, so our caller should call
+ // child.checkCachingLocked(). curParent gained a ref so we should also
+ // call curParent.checkCachingLocked() so it can be removed from the cache
+ // if needed. We only do that for the first iteration because all
+ // subsequent parents would have already been added to ds.
+ if i == 0 {
+ *ds = appendDentry(*ds, curParent)
+ }
+ *ds = appendDentry(*ds, child)
+ curParent = child
+ if i == 0 {
+ // Only the first component's dentry is returned to the caller.
+ ret = child
+ }
+ }
+
+ if status == lisafs.WalkComponentDoesNotExist && curParent.isDir() {
+ // names[len(inodes)] is the first component the server could not walk;
+ // cache it as a negative lookup under the deepest dentry reached.
+ curParentDirMuLock()
+ curParent.cacheNegativeLookupLocked(names[len(inodes)])
+ curParentDirMuUnlock()
+ }
+ return ret, dentryCreationErr
+}
+
// getChildLocked returns a dentry representing the child of parent with the
// given name. Returns ENOENT if the child doesn't exist.
//
@@ -227,7 +366,7 @@ afterSymlink:
// * parent.dirMu must be locked.
// * parent.isDir().
// * name is not "." or "..".
-// * dentry at name has been revalidated
+// * parent and the dentry at name have been revalidated.
func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if len(name) > maxFilenameLen {
return nil, linuxerr.ENAMETOOLONG
@@ -239,20 +378,35 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
return child, nil
}
- qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
- if err != nil {
- if linuxerr.Equals(linuxerr.ENOENT, err) {
- parent.cacheNegativeLookupLocked(name)
+ var child *dentry
+ if fs.opts.lisaEnabled {
+ childInode, err := parent.controlFDLisa.Walk(ctx, name)
+ if err != nil {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
+ parent.cacheNegativeLookupLocked(name)
+ }
+ return nil, err
+ }
+ // Create a new dentry representing the file.
+ child, err = fs.newDentryLisa(ctx, childInode)
+ if err != nil {
+ fs.clientLisa.CloseFDBatched(ctx, childInode.ControlFD)
+ return nil, err
+ }
+ } else {
+ qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
+ if err != nil {
+ if linuxerr.Equals(linuxerr.ENOENT, err) {
+ parent.cacheNegativeLookupLocked(name)
+ }
+ return nil, err
+ }
+ // Create a new dentry representing the file.
+ child, err = fs.newDentry(ctx, file, qid, attrMask, &attr)
+ if err != nil {
+ file.close(ctx)
+ return nil, err
}
- return nil, err
- }
-
- // Create a new dentry representing the file.
- child, err := fs.newDentry(ctx, file, qid, attrMask, &attr)
- if err != nil {
- file.close(ctx)
- delete(parent.children, name)
- return nil, err
}
parent.cacheNewChildLocked(child, name)
appendNewChildDentry(ds, parent, child)
@@ -328,7 +482,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath,
// Preconditions:
// * !rp.Done().
// * For the final path component in rp, !rp.ShouldFollowSymlink().
-func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string, ds **[]*dentry) error, createInSyntheticDir func(parent *dentry, name string) error) error {
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error), createInSyntheticDir func(parent *dentry, name string) error, updateChild func(child *dentry)) error {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
@@ -415,9 +569,26 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
// No cached dentry exists; however, in InteropModeShared there might still be
// an existing file at name. Just attempt the file creation RPC anyways. If a
// file does exist, the RPC will fail with EEXIST like we would have.
- if err := createInRemoteDir(parent, name, &ds); err != nil {
+ lisaInode, err := createInRemoteDir(parent, name, &ds)
+ if err != nil {
return err
}
+ // lisafs may aggressively cache newly created inodes. This has helped reduce
+ // Walk RPCs in practice.
+ if lisaInode != nil {
+ child, err := fs.newDentryLisa(ctx, lisaInode)
+ if err != nil {
+ fs.clientLisa.CloseFDBatched(ctx, lisaInode.ControlFD)
+ return err
+ }
+ parent.cacheNewChildLocked(child, name)
+ appendNewChildDentry(&ds, parent, child)
+
+ // lisafs may update dentry properties upon successful creation.
+ if updateChild != nil {
+ updateChild(child)
+ }
+ }
if fs.opts.interop != InteropModeShared {
if child, ok := parent.children[name]; ok && child == nil {
// Delete the now-stale negative dentry.
@@ -565,7 +736,11 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
return linuxerr.ENOENT
}
} else if child == nil || !child.isSynthetic() {
- err = parent.file.unlinkAt(ctx, name, flags)
+ if fs.opts.lisaEnabled {
+ err = parent.controlFDLisa.UnlinkAt(ctx, name, flags)
+ } else {
+ err = parent.file.unlinkAt(ctx, name, flags)
+ }
if err != nil {
if child != nil {
vfsObj.AbortDeleteDentry(&child.vfsd) // +checklocksforce: see above.
@@ -658,40 +833,43 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa
// LinkAt implements vfs.FilesystemImpl.LinkAt.
func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, _ **[]*dentry) error {
+ err := fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, ds **[]*dentry) (*lisafs.Inode, error) {
if rp.Mount() != vd.Mount() {
- return linuxerr.EXDEV
+ return nil, linuxerr.EXDEV
}
d := vd.Dentry().Impl().(*dentry)
if d.isDir() {
- return linuxerr.EPERM
+ return nil, linuxerr.EPERM
}
gid := auth.KGID(atomic.LoadUint32(&d.gid))
uid := auth.KUID(atomic.LoadUint32(&d.uid))
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
if err := vfs.MayLink(rp.Credentials(), mode, uid, gid); err != nil {
- return err
+ return nil, err
}
if d.nlink == 0 {
- return linuxerr.ENOENT
+ return nil, linuxerr.ENOENT
}
if d.nlink == math.MaxUint32 {
- return linuxerr.EMLINK
+ return nil, linuxerr.EMLINK
}
- if err := parent.file.link(ctx, d.file, childName); err != nil {
- return err
+ if fs.opts.lisaEnabled {
+ return parent.controlFDLisa.LinkAt(ctx, d.controlFDLisa.ID(), childName)
}
+ return nil, parent.file.link(ctx, d.file, childName)
+ }, nil, nil)
+ if err == nil {
// Success!
- atomic.AddUint32(&d.nlink, 1)
- return nil
- }, nil)
+ vd.Dentry().Impl().(*dentry).incLinks()
+ }
+ return err
}
// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
creds := rp.Credentials()
- return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) error {
+ return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) {
// If the parent is a setgid directory, use the parent's GID
// rather than the caller's and enable setgid.
kgid := creds.EffectiveKGID
@@ -700,23 +878,37 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
kgid = auth.KGID(atomic.LoadUint32(&parent.gid))
mode |= linux.S_ISGID
}
- if _, err := parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid)); err != nil {
- if !opts.ForSyntheticMountpoint || linuxerr.Equals(linuxerr.EEXIST, err) {
- return err
+ var (
+ childDirInode *lisafs.Inode
+ err error
+ )
+ if fs.opts.lisaEnabled {
+ childDirInode, err = parent.controlFDLisa.MkdirAt(ctx, name, mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(kgid))
+ } else {
+ _, err = parent.file.mkdir(ctx, name, p9.FileMode(mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid))
+ }
+ if err == nil {
+ if fs.opts.interop != InteropModeShared {
+ parent.incLinks()
}
- ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err)
- parent.createSyntheticChildLocked(&createSyntheticOpts{
- name: name,
- mode: linux.S_IFDIR | opts.Mode,
- kuid: creds.EffectiveKUID,
- kgid: creds.EffectiveKGID,
- })
- *ds = appendDentry(*ds, parent)
+ return childDirInode, nil
+ }
+
+ if !opts.ForSyntheticMountpoint || linuxerr.Equals(linuxerr.EEXIST, err) {
+ return nil, err
}
+ ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err)
+ parent.createSyntheticChildLocked(&createSyntheticOpts{
+ name: name,
+ mode: linux.S_IFDIR | opts.Mode,
+ kuid: creds.EffectiveKUID,
+ kgid: creds.EffectiveKGID,
+ })
+ *ds = appendDentry(*ds, parent)
if fs.opts.interop != InteropModeShared {
parent.incLinks()
}
- return nil
+ return nil, nil
}, func(parent *dentry, name string) error {
if !opts.ForSyntheticMountpoint {
// Can't create non-synthetic files in synthetic directories.
@@ -730,16 +922,26 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
})
parent.incLinks()
return nil
- })
+ }, nil)
}
// MknodAt implements vfs.FilesystemImpl.MknodAt.
func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) {
creds := rp.Credentials()
- _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
- if !linuxerr.Equals(linuxerr.EPERM, err) {
- return err
+ var (
+ childInode *lisafs.Inode
+ err error
+ )
+ if fs.opts.lisaEnabled {
+ childInode, err = parent.controlFDLisa.MknodAt(ctx, name, opts.Mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(creds.EffectiveKGID), opts.DevMinor, opts.DevMajor)
+ } else {
+ _, err = parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ }
+ if err == nil {
+ return childInode, nil
+ } else if !linuxerr.Equals(linuxerr.EPERM, err) {
+ return nil, err
}
// EPERM means that gofer does not allow creating a socket or pipe. Fallback
@@ -750,10 +952,10 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
switch {
case err == nil:
// Step succeeded, another file exists.
- return linuxerr.EEXIST
+ return nil, linuxerr.EEXIST
case !linuxerr.Equals(linuxerr.ENOENT, err):
// Unexpected error.
- return err
+ return nil, err
}
switch opts.Mode.FileType() {
@@ -766,7 +968,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
endpoint: opts.Endpoint,
})
*ds = appendDentry(*ds, parent)
- return nil
+ return nil, nil
case linux.S_IFIFO:
parent.createSyntheticChildLocked(&createSyntheticOpts{
name: name,
@@ -776,11 +978,11 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize),
})
*ds = appendDentry(*ds, parent)
- return nil
+ return nil, nil
}
// Retain error from gofer if synthetic file cannot be created internally.
- return linuxerr.EPERM
- }, nil)
+ return nil, linuxerr.EPERM
+ }, nil, nil)
}
// OpenAt implements vfs.FilesystemImpl.OpenAt.
@@ -986,6 +1188,23 @@ func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptio
if opts.Flags&linux.O_DIRECT != 0 {
return nil, linuxerr.EINVAL
}
+ if d.fs.opts.lisaEnabled {
+ // Note that special value of linux.SockType = 0 is interpreted by lisafs
+ // as "do not care about the socket type". Analogous to p9.AnonymousSocket.
+ sockFD, err := d.controlFDLisa.Connect(ctx, 0 /* sockType */)
+ if err != nil {
+ return nil, err
+ }
+ fd, err := host.NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), sockFD, &host.NewFDOptions{
+ HaveFlags: true,
+ Flags: opts.Flags,
+ })
+ if err != nil {
+ unix.Close(sockFD)
+ return nil, err
+ }
+ return fd, nil
+ }
fdObj, err := d.file.connect(ctx, p9.AnonymousSocket)
if err != nil {
return nil, err
@@ -998,6 +1217,7 @@ func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptio
fdObj.Close()
return nil, err
}
+ // Ownership has been transferred to fd.
fdObj.Release()
return fd, nil
}
@@ -1017,7 +1237,13 @@ func (d *dentry) openSpecialFile(ctx context.Context, mnt *vfs.Mount, opts *vfs.
// since closed its end.
isBlockingOpenOfNamedPipe := d.fileType() == linux.S_IFIFO && opts.Flags&linux.O_NONBLOCK == 0
retry:
- h, err := openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0)
+ var h handle
+ var err error
+ if d.fs.opts.lisaEnabled {
+ h, err = openHandleLisa(ctx, d.controlFDLisa, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0)
+ } else {
+ h, err = openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0)
+ }
if err != nil {
if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && linuxerr.Equals(linuxerr.ENXIO, err) {
// An attempt to open a named pipe with O_WRONLY|O_NONBLOCK fails
@@ -1061,18 +1287,8 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
}
defer mnt.EndWrite()
- // 9P2000.L's lcreate takes a fid representing the parent directory, and
- // converts it into an open fid representing the created file, so we need
- // to duplicate the directory fid first.
- _, dirfile, err := d.file.walk(ctx, nil)
- if err != nil {
- return nil, err
- }
creds := rp.Credentials()
name := rp.Component()
- // We only want the access mode for creating the file.
- createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask
-
// If the parent is a setgid directory, use the parent's GID rather
// than the caller's.
kgid := creds.EffectiveKGID
@@ -1080,51 +1296,87 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
kgid = auth.KGID(atomic.LoadUint32(&d.gid))
}
- fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid))
- if err != nil {
- dirfile.close(ctx)
- return nil, err
- }
- // Then we need to walk to the file we just created to get a non-open fid
- // representing it, and to get its metadata. This must use d.file since, as
- // explained above, dirfile was invalidated by dirfile.Create().
- _, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name)
- if err != nil {
- openFile.close(ctx)
- if fdobj != nil {
- fdobj.Close()
+ var child *dentry
+ var openP9File p9file
+ openLisaFD := lisafs.InvalidFDID
+ openHostFD := int32(-1)
+ if d.fs.opts.lisaEnabled {
+ ino, openFD, hostFD, err := d.controlFDLisa.OpenCreateAt(ctx, name, opts.Flags&linux.O_ACCMODE, opts.Mode, lisafs.UID(creds.EffectiveKUID), lisafs.GID(kgid))
+ if err != nil {
+ return nil, err
+ }
+ openHostFD = int32(hostFD)
+ openLisaFD = openFD
+
+ child, err = d.fs.newDentryLisa(ctx, &ino)
+ if err != nil {
+ d.fs.clientLisa.CloseFDBatched(ctx, ino.ControlFD)
+ d.fs.clientLisa.CloseFDBatched(ctx, openFD)
+ if hostFD >= 0 {
+ unix.Close(hostFD)
+ }
+ return nil, err
+ }
+ } else {
+ // 9P2000.L's lcreate takes a fid representing the parent directory, and
+ // converts it into an open fid representing the created file, so we need
+ // to duplicate the directory fid first.
+ _, dirfile, err := d.file.walk(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+ // We only want the access mode for creating the file.
+ createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask
+
+ fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, p9.FileMode(opts.Mode), (p9.UID)(creds.EffectiveKUID), p9.GID(kgid))
+ if err != nil {
+ dirfile.close(ctx)
+ return nil, err
+ }
+ // Then we need to walk to the file we just created to get a non-open fid
+ // representing it, and to get its metadata. This must use d.file since, as
+ // explained above, dirfile was invalidated by dirfile.Create().
+ _, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name)
+ if err != nil {
+ openFile.close(ctx)
+ if fdobj != nil {
+ fdobj.Close()
+ }
+ return nil, err
+ }
+
+ // Construct the new dentry.
+ child, err = d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr)
+ if err != nil {
+ nonOpenFile.close(ctx)
+ openFile.close(ctx)
+ if fdobj != nil {
+ fdobj.Close()
+ }
+ return nil, err
}
- return nil, err
- }
- // Construct the new dentry.
- child, err := d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr)
- if err != nil {
- nonOpenFile.close(ctx)
- openFile.close(ctx)
if fdobj != nil {
- fdobj.Close()
+ openHostFD = int32(fdobj.Release())
}
- return nil, err
+ openP9File = openFile
}
// Incorporate the fid that was opened by lcreate.
useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD
if useRegularFileFD {
- openFD := int32(-1)
- if fdobj != nil {
- openFD = int32(fdobj.Release())
- }
child.handleMu.Lock()
if vfs.MayReadFileWithOpenFlags(opts.Flags) {
- child.readFile = openFile
- if fdobj != nil {
- child.readFD = openFD
- child.mmapFD = openFD
+ child.readFile = openP9File
+ child.readFDLisa = d.fs.clientLisa.NewFD(openLisaFD)
+ if openHostFD != -1 {
+ child.readFD = openHostFD
+ child.mmapFD = openHostFD
}
}
if vfs.MayWriteFileWithOpenFlags(opts.Flags) {
- child.writeFile = openFile
- child.writeFD = openFD
+ child.writeFile = openP9File
+ child.writeFDLisa = d.fs.clientLisa.NewFD(openLisaFD)
+ child.writeFD = openHostFD
}
child.handleMu.Unlock()
}
@@ -1146,11 +1398,9 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
childVFSFD = &fd.vfsfd
} else {
h := handle{
- file: openFile,
- fd: -1,
- }
- if fdobj != nil {
- h.fd = int32(fdobj.Release())
+ file: openP9File,
+ fdLisa: d.fs.clientLisa.NewFD(openLisaFD),
+ fd: openHostFD,
}
fd, err := newSpecialFileFD(h, mnt, child, opts.Flags)
if err != nil {
@@ -1304,7 +1554,12 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// Update the remote filesystem.
if !renamed.isSynthetic() {
- if err := renamed.file.rename(ctx, newParent.file, newName); err != nil {
+ if fs.opts.lisaEnabled {
+ err = renamed.controlFDLisa.RenameTo(ctx, newParent.controlFDLisa.ID(), newName)
+ } else {
+ err = renamed.file.rename(ctx, newParent.file, newName)
+ }
+ if err != nil {
vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
return err
}
@@ -1315,7 +1570,12 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if replaced.isDir() {
flags = linux.AT_REMOVEDIR
}
- if err := newParent.file.unlinkAt(ctx, newName, flags); err != nil {
+ if fs.opts.lisaEnabled {
+ err = newParent.controlFDLisa.UnlinkAt(ctx, newName, flags)
+ } else {
+ err = newParent.file.unlinkAt(ctx, newName, flags)
+ }
+ if err != nil {
vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
return err
}
@@ -1431,6 +1691,28 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
for d.isSynthetic() {
d = d.parent
}
+ if fs.opts.lisaEnabled {
+ var statFS lisafs.StatFS
+ if err := d.controlFDLisa.StatFSTo(ctx, &statFS); err != nil {
+ return linux.Statfs{}, err
+ }
+ if statFS.NameLength > maxFilenameLen {
+ statFS.NameLength = maxFilenameLen
+ }
+ return linux.Statfs{
+ // This is primarily for distinguishing a gofer file system in
+ // tests. Testing is important, so instead of defining
+ // something completely random, use a standard value.
+ Type: linux.V9FS_MAGIC,
+ BlockSize: statFS.BlockSize,
+ Blocks: statFS.Blocks,
+ BlocksFree: statFS.BlocksFree,
+ BlocksAvailable: statFS.BlocksAvailable,
+ Files: statFS.Files,
+ FilesFree: statFS.FilesFree,
+ NameLength: statFS.NameLength,
+ }, nil
+ }
fsstat, err := d.file.statFS(ctx)
if err != nil {
return linux.Statfs{}, err
@@ -1456,11 +1738,21 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
- return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, _ **[]*dentry) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) (*lisafs.Inode, error) {
creds := rp.Credentials()
+ if fs.opts.lisaEnabled {
+ return parent.controlFDLisa.SymlinkAt(ctx, name, target, lisafs.UID(creds.EffectiveKUID), lisafs.GID(creds.EffectiveKGID))
+ }
_, err := parent.file.symlink(ctx, target, name, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
- return err
- }, nil)
+ return nil, err
+ }, nil, func(child *dentry) {
+ if fs.opts.interop != InteropModeShared {
+ // lisafs caches the symlink target on creation. In practice, this
+ // helps avoid a lot of ReadLink RPCs.
+ child.haveTarget = true
+ child.target = target
+ }
+ })
}
// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
@@ -1505,7 +1797,7 @@ func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, si
if err != nil {
return nil, err
}
- return d.listXattr(ctx, rp.Credentials(), size)
+ return d.listXattr(ctx, size)
}
// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt.
@@ -1612,6 +1904,9 @@ func (fs *filesystem) MountOptions() string {
if fs.opts.overlayfsStaleRead {
optsKV = append(optsKV, mopt{moptOverlayfsStaleRead, nil})
}
+ if fs.opts.lisaEnabled {
+ optsKV = append(optsKV, mopt{moptLisafs, nil})
+ }
opts := make([]string, 0, len(optsKV))
for _, opt := range optsKV {
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 43440ec19..b98825e26 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -48,6 +48,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
@@ -83,6 +84,7 @@ const (
moptForcePageCache = "force_page_cache"
moptLimitHostFDTranslation = "limit_host_fd_translation"
moptOverlayfsStaleRead = "overlayfs_stale_read"
+ moptLisafs = "lisafs"
)
// Valid values for the "cache" mount option.
@@ -118,6 +120,10 @@ type filesystem struct {
// client is the client used by this filesystem. client is immutable.
client *p9.Client `state:"nosave"`
+ // clientLisa is the client used for communicating with the server when
+ // lisafs is enabled. clientLisa is immutable.
+ clientLisa *lisafs.Client `state:"nosave"`
+
// clock is a realtime clock used to set timestamps in file operations.
clock ktime.Clock
@@ -161,6 +167,12 @@ type filesystem struct {
inoMu sync.Mutex `state:"nosave"`
inoByQIDPath map[uint64]uint64 `state:"nosave"`
+ // inoByKey is the same as inoByQIDPath but only used by lisafs. It helps
+ // identify inodes based on the device ID and host inode number provided
+ // by the gofer process. It is not preserved across checkpoint/restore for
+ // the same reason as above. inoByKey is protected by inoMu.
+ inoByKey map[inoKey]uint64 `state:"nosave"`
+
// lastIno is the last inode number assigned to a file. lastIno is accessed
// using atomic memory operations.
lastIno uint64
@@ -214,6 +226,10 @@ type filesystemOptions struct {
// way that application FDs representing "special files" such as sockets
// do. Note that this disables client caching and mmap for regular files.
regularFilesUseSpecialFileFD bool
+
+ // lisaEnabled indicates whether the client will use lisafs protocol to
+ // communicate with the server instead of 9P.
+ lisaEnabled bool
}
// InteropMode controls the client's interaction with other remote filesystem
@@ -427,6 +443,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
delete(mopts, moptOverlayfsStaleRead)
fsopts.overlayfsStaleRead = true
}
+ if lisafs, ok := mopts[moptLisafs]; ok {
+ delete(mopts, moptLisafs)
+ fsopts.lisaEnabled, err = strconv.ParseBool(lisafs)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid lisafs option: %s", lisafs)
+ return nil, nil, linuxerr.EINVAL
+ }
+ }
// fsopts.regularFilesUseSpecialFileFD can only be enabled by specifying
// "cache=none".
@@ -458,44 +482,83 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
syncableDentries: make(map[*dentry]struct{}),
specialFileFDs: make(map[*specialFileFD]struct{}),
inoByQIDPath: make(map[uint64]uint64),
+ inoByKey: make(map[inoKey]uint64),
}
fs.vfsfs.Init(vfsObj, &fstype, fs)
+ if err := fs.initClientAndRoot(ctx); err != nil {
+ fs.vfsfs.DecRef(ctx)
+ return nil, nil, err
+ }
+
+ return &fs.vfsfs, &fs.root.vfsd, nil
+}
+
+func (fs *filesystem) initClientAndRoot(ctx context.Context) error {
+ var err error
+ if fs.opts.lisaEnabled {
+ var rootInode *lisafs.Inode
+ rootInode, err = fs.initClientLisa(ctx)
+ if err != nil {
+ return err
+ }
+ fs.root, err = fs.newDentryLisa(ctx, rootInode)
+ if err != nil {
+ fs.clientLisa.CloseFDBatched(ctx, rootInode.ControlFD)
+ }
+ } else {
+ fs.root, err = fs.initClient(ctx)
+ }
+
+ // Set the root's reference count to 2. One reference is returned to the
+ // caller, and the other is held by fs to prevent the root from being "cached"
+ // and subsequently evicted.
+ if err == nil {
+ fs.root.refs = 2
+ }
+ return err
+}
+
+func (fs *filesystem) initClientLisa(ctx context.Context) (*lisafs.Inode, error) {
+ sock, err := unet.NewSocket(fs.opts.fd)
+ if err != nil {
+ return nil, err
+ }
+
+ var rootInode *lisafs.Inode
+ ctx.UninterruptibleSleepStart(false)
+ fs.clientLisa, rootInode, err = lisafs.NewClient(sock, fs.opts.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ return rootInode, err
+}
+
+func (fs *filesystem) initClient(ctx context.Context) (*dentry, error) {
// Connect to the server.
if err := fs.dial(ctx); err != nil {
- return nil, nil, err
+ return nil, err
}
// Perform attach to obtain the filesystem root.
ctx.UninterruptibleSleepStart(false)
- attached, err := fs.client.Attach(fsopts.aname)
+ attached, err := fs.client.Attach(fs.opts.aname)
ctx.UninterruptibleSleepFinish(false)
if err != nil {
- fs.vfsfs.DecRef(ctx)
- return nil, nil, err
+ return nil, err
}
attachFile := p9file{attached}
qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
if err != nil {
attachFile.close(ctx)
- fs.vfsfs.DecRef(ctx)
- return nil, nil, err
+ return nil, err
}
// Construct the root dentry.
root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr)
if err != nil {
attachFile.close(ctx)
- fs.vfsfs.DecRef(ctx)
- return nil, nil, err
+ return nil, err
}
- // Set the root's reference count to 2. One reference is returned to the
- // caller, and the other is held by fs to prevent the root from being "cached"
- // and subsequently evicted.
- root.refs = 2
- fs.root = root
-
- return &fs.vfsfs, &root.vfsd, nil
+ return root, nil
}
func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) {
@@ -613,7 +676,11 @@ func (fs *filesystem) Release(ctx context.Context) {
if !fs.iopts.LeakConnection {
// Close the connection to the server. This implicitly clunks all fids.
- fs.client.Close()
+ if fs.opts.lisaEnabled {
+ fs.clientLisa.Close()
+ } else {
+ fs.client.Close()
+ }
}
fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
@@ -644,6 +711,23 @@ func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) {
}
}
+// inoKey is the key used to identify the inode backed by this dentry.
+//
+// +stateify savable
+type inoKey struct {
+ ino uint64
+ devMinor uint32
+ devMajor uint32
+}
+
+func inoKeyFromStat(stat *linux.Statx) inoKey {
+ return inoKey{
+ ino: stat.Ino,
+ devMinor: stat.DevMinor,
+ devMajor: stat.DevMajor,
+ }
+}
+
// dentry implements vfs.DentryImpl.
//
// +stateify savable
@@ -674,6 +758,9 @@ type dentry struct {
// qidPath is the p9.QID.Path for this file. qidPath is immutable.
qidPath uint64
+ // inoKey is used to identify this dentry's inode.
+ inoKey inoKey
+
// file is the unopened p9.File that backs this dentry. file is immutable.
//
// If file.isNil(), this dentry represents a synthetic file, i.e. a file
@@ -681,6 +768,14 @@ type dentry struct {
// only files that can be synthetic are sockets, pipes, and directories.
file p9file `state:"nosave"`
+ // controlFDLisa is used by lisafs to perform path based operations on this
+ // dentry.
+ //
+ // if !controlFDLisa.Ok(), this dentry represents a synthetic file, i.e. a
+ // file that does not exist on the remote filesystem. As of this writing, the
+ // only files that can be synthetic are sockets, pipes, and directories.
+ controlFDLisa lisafs.ClientFD `state:"nosave"`
+
// If deleted is non-zero, the file represented by this dentry has been
// deleted. deleted is accessed using atomic memory operations.
deleted uint32
@@ -791,12 +886,14 @@ type dentry struct {
// always either -1 or equal to readFD; if !writeFile.isNil() (the file has
// been opened for writing), it is additionally either -1 or equal to
// writeFD.
- handleMu sync.RWMutex `state:"nosave"`
- readFile p9file `state:"nosave"`
- writeFile p9file `state:"nosave"`
- readFD int32 `state:"nosave"`
- writeFD int32 `state:"nosave"`
- mmapFD int32 `state:"nosave"`
+ handleMu sync.RWMutex `state:"nosave"`
+ readFile p9file `state:"nosave"`
+ writeFile p9file `state:"nosave"`
+ readFDLisa lisafs.ClientFD `state:"nosave"`
+ writeFDLisa lisafs.ClientFD `state:"nosave"`
+ readFD int32 `state:"nosave"`
+ writeFD int32 `state:"nosave"`
+ mmapFD int32 `state:"nosave"`
dataMu sync.RWMutex `state:"nosave"`
@@ -920,6 +1017,79 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
return d, nil
}
+func (fs *filesystem) newDentryLisa(ctx context.Context, ino *lisafs.Inode) (*dentry, error) {
+ if ino.Stat.Mask&linux.STATX_TYPE == 0 {
+ ctx.Warningf("can't create gofer.dentry without file type")
+ return nil, linuxerr.EIO
+ }
+ if ino.Stat.Mode&linux.FileTypeMask == linux.ModeRegular && ino.Stat.Mask&linux.STATX_SIZE == 0 {
+ ctx.Warningf("can't create regular file gofer.dentry without file size")
+ return nil, linuxerr.EIO
+ }
+
+ inoKey := inoKeyFromStat(&ino.Stat)
+ d := &dentry{
+ fs: fs,
+ inoKey: inoKey,
+ ino: fs.inoFromKey(inoKey),
+ mode: uint32(ino.Stat.Mode),
+ uid: uint32(fs.opts.dfltuid),
+ gid: uint32(fs.opts.dfltgid),
+ blockSize: hostarch.PageSize,
+ readFD: -1,
+ writeFD: -1,
+ mmapFD: -1,
+ controlFDLisa: fs.clientLisa.NewFD(ino.ControlFD),
+ }
+
+ d.pf.dentry = d
+ if ino.Stat.Mask&linux.STATX_UID != 0 {
+ d.uid = dentryUIDFromLisaUID(lisafs.UID(ino.Stat.UID))
+ }
+ if ino.Stat.Mask&linux.STATX_GID != 0 {
+ d.gid = dentryGIDFromLisaGID(lisafs.GID(ino.Stat.GID))
+ }
+ if ino.Stat.Mask&linux.STATX_SIZE != 0 {
+ d.size = ino.Stat.Size
+ }
+ if ino.Stat.Blksize != 0 {
+ d.blockSize = ino.Stat.Blksize
+ }
+ if ino.Stat.Mask&linux.STATX_ATIME != 0 {
+ d.atime = dentryTimestampFromLisa(ino.Stat.Atime)
+ }
+ if ino.Stat.Mask&linux.STATX_MTIME != 0 {
+ d.mtime = dentryTimestampFromLisa(ino.Stat.Mtime)
+ }
+ if ino.Stat.Mask&linux.STATX_CTIME != 0 {
+ d.ctime = dentryTimestampFromLisa(ino.Stat.Ctime)
+ }
+ if ino.Stat.Mask&linux.STATX_BTIME != 0 {
+ d.btime = dentryTimestampFromLisa(ino.Stat.Btime)
+ }
+ if ino.Stat.Mask&linux.STATX_NLINK != 0 {
+ d.nlink = ino.Stat.Nlink
+ }
+ d.vfsd.Init(d)
+ refsvfs2.Register(d)
+ fs.syncMu.Lock()
+ fs.syncableDentries[d] = struct{}{}
+ fs.syncMu.Unlock()
+ return d, nil
+}
+
+func (fs *filesystem) inoFromKey(key inoKey) uint64 {
+ fs.inoMu.Lock()
+ defer fs.inoMu.Unlock()
+
+ if ino, ok := fs.inoByKey[key]; ok {
+ return ino
+ }
+ ino := fs.nextIno()
+ fs.inoByKey[key] = ino
+ return ino
+}
+
func (fs *filesystem) inoFromQIDPath(qidPath uint64) uint64 {
fs.inoMu.Lock()
defer fs.inoMu.Unlock()
@@ -936,7 +1106,7 @@ func (fs *filesystem) nextIno() uint64 {
}
func (d *dentry) isSynthetic() bool {
- return d.file.isNil()
+ return !d.isControlFileOk()
}
func (d *dentry) cachedMetadataAuthoritative() bool {
@@ -986,6 +1156,50 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
}
}
+// updateFromLisaStatLocked is called to update d's metadata after an update
+// from the remote filesystem.
+// Precondition: d.metadataMu must be locked.
+// +checklocks:d.metadataMu
+func (d *dentry) updateFromLisaStatLocked(stat *linux.Statx) {
+ if stat.Mask&linux.STATX_TYPE != 0 {
+ if got, want := stat.Mode&linux.FileTypeMask, d.fileType(); uint32(got) != want {
+ panic(fmt.Sprintf("gofer.dentry file type changed from %#o to %#o", want, got))
+ }
+ }
+ if stat.Mask&linux.STATX_MODE != 0 {
+ atomic.StoreUint32(&d.mode, uint32(stat.Mode))
+ }
+ if stat.Mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&d.uid, dentryUIDFromLisaUID(lisafs.UID(stat.UID)))
+ }
+ if stat.Mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&d.gid, dentryGIDFromLisaGID(lisafs.GID(stat.GID)))
+ }
+ if stat.Blksize != 0 {
+ atomic.StoreUint32(&d.blockSize, stat.Blksize)
+ }
+ // Don't override newer client-defined timestamps with old server-defined
+ // ones.
+ if stat.Mask&linux.STATX_ATIME != 0 && atomic.LoadUint32(&d.atimeDirty) == 0 {
+ atomic.StoreInt64(&d.atime, dentryTimestampFromLisa(stat.Atime))
+ }
+ if stat.Mask&linux.STATX_MTIME != 0 && atomic.LoadUint32(&d.mtimeDirty) == 0 {
+ atomic.StoreInt64(&d.mtime, dentryTimestampFromLisa(stat.Mtime))
+ }
+ if stat.Mask&linux.STATX_CTIME != 0 {
+ atomic.StoreInt64(&d.ctime, dentryTimestampFromLisa(stat.Ctime))
+ }
+ if stat.Mask&linux.STATX_BTIME != 0 {
+ atomic.StoreInt64(&d.btime, dentryTimestampFromLisa(stat.Btime))
+ }
+ if stat.Mask&linux.STATX_NLINK != 0 {
+ atomic.StoreUint32(&d.nlink, stat.Nlink)
+ }
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ d.updateSizeLocked(stat.Size)
+ }
+}
+
// Preconditions: !d.isSynthetic().
// Preconditions: d.metadataMu is locked.
// +checklocks:d.metadataMu
@@ -995,6 +1209,9 @@ func (d *dentry) refreshSizeLocked(ctx context.Context) error {
if d.writeFD < 0 {
d.handleMu.RUnlock()
// Ask the gofer if we don't have a host FD.
+ if d.fs.opts.lisaEnabled {
+ return d.updateFromStatLisaLocked(ctx, nil)
+ }
return d.updateFromGetattrLocked(ctx, p9file{})
}
@@ -1014,6 +1231,9 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
// updating stale attributes in d.updateFromP9AttrsLocked().
d.metadataMu.Lock()
defer d.metadataMu.Unlock()
+ if d.fs.opts.lisaEnabled {
+ return d.updateFromStatLisaLocked(ctx, nil)
+ }
return d.updateFromGetattrLocked(ctx, p9file{})
}
@@ -1021,6 +1241,45 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
// * !d.isSynthetic().
// * d.metadataMu is locked.
// +checklocks:d.metadataMu
+func (d *dentry) updateFromStatLisaLocked(ctx context.Context, fdLisa *lisafs.ClientFD) error {
+ handleMuRLocked := false
+ if fdLisa == nil {
+ // Use open FDs in preference to the control FD. This may be significantly
+ // more efficient in some implementations. Prefer a writable FD over a
+ // readable one since some filesystem implementations may update a writable
+ // FD's metadata after writes, without making metadata updates immediately
+ // visible to read-only FDs representing the same file.
+ d.handleMu.RLock()
+ switch {
+ case d.writeFDLisa.Ok():
+ fdLisa = &d.writeFDLisa
+ handleMuRLocked = true
+ case d.readFDLisa.Ok():
+ fdLisa = &d.readFDLisa
+ handleMuRLocked = true
+ default:
+ fdLisa = &d.controlFDLisa
+ d.handleMu.RUnlock()
+ }
+ }
+
+ var stat linux.Statx
+ err := fdLisa.StatTo(ctx, &stat)
+ if handleMuRLocked {
+ // handleMu must be released before updateFromLisaStatLocked().
+ d.handleMu.RUnlock() // +checklocksforce: complex case.
+ }
+ if err != nil {
+ return err
+ }
+ d.updateFromLisaStatLocked(&stat)
+ return nil
+}
+
+// Preconditions:
+// * !d.isSynthetic().
+// * d.metadataMu is locked.
+// +checklocks:d.metadataMu
func (d *dentry) updateFromGetattrLocked(ctx context.Context, file p9file) error {
handleMuRLocked := false
if file.isNil() {
@@ -1160,6 +1419,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
}
}
+ // failureMask indicates which attributes could not be set on the remote
+ // filesystem. p9 returns an error if any of the attributes could not be set
+ // but that leads to inconsistency as the server could have set a few
+ // attributes successfully but a later failure will cause the successful ones
+ // to not be updated in the dentry cache.
+ var failureMask uint32
+ var failureErr error
if !d.isSynthetic() {
if stat.Mask != 0 {
if stat.Mask&linux.STATX_SIZE != 0 {
@@ -1169,35 +1435,50 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
// the remote file has been truncated).
d.dataMu.Lock()
}
- if err := d.file.setAttr(ctx, p9.SetAttrMask{
- Permissions: stat.Mask&linux.STATX_MODE != 0,
- UID: stat.Mask&linux.STATX_UID != 0,
- GID: stat.Mask&linux.STATX_GID != 0,
- Size: stat.Mask&linux.STATX_SIZE != 0,
- ATime: stat.Mask&linux.STATX_ATIME != 0,
- MTime: stat.Mask&linux.STATX_MTIME != 0,
- ATimeNotSystemTime: stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW,
- MTimeNotSystemTime: stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW,
- }, p9.SetAttr{
- Permissions: p9.FileMode(stat.Mode),
- UID: p9.UID(stat.UID),
- GID: p9.GID(stat.GID),
- Size: stat.Size,
- ATimeSeconds: uint64(stat.Atime.Sec),
- ATimeNanoSeconds: uint64(stat.Atime.Nsec),
- MTimeSeconds: uint64(stat.Mtime.Sec),
- MTimeNanoSeconds: uint64(stat.Mtime.Nsec),
- }); err != nil {
- if stat.Mask&linux.STATX_SIZE != 0 {
- d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ if d.fs.opts.lisaEnabled {
+ var err error
+ failureMask, failureErr, err = d.controlFDLisa.SetStat(ctx, stat)
+ if err != nil {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ }
+ return err
+ }
+ } else {
+ if err := d.file.setAttr(ctx, p9.SetAttrMask{
+ Permissions: stat.Mask&linux.STATX_MODE != 0,
+ UID: stat.Mask&linux.STATX_UID != 0,
+ GID: stat.Mask&linux.STATX_GID != 0,
+ Size: stat.Mask&linux.STATX_SIZE != 0,
+ ATime: stat.Mask&linux.STATX_ATIME != 0,
+ MTime: stat.Mask&linux.STATX_MTIME != 0,
+ ATimeNotSystemTime: stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW,
+ MTimeNotSystemTime: stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW,
+ }, p9.SetAttr{
+ Permissions: p9.FileMode(stat.Mode),
+ UID: p9.UID(stat.UID),
+ GID: p9.GID(stat.GID),
+ Size: stat.Size,
+ ATimeSeconds: uint64(stat.Atime.Sec),
+ ATimeNanoSeconds: uint64(stat.Atime.Nsec),
+ MTimeSeconds: uint64(stat.Mtime.Sec),
+ MTimeNanoSeconds: uint64(stat.Mtime.Nsec),
+ }); err != nil {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ }
+ return err
}
- return err
}
if stat.Mask&linux.STATX_SIZE != 0 {
- // d.size should be kept up to date, and privatized
- // copy-on-write mappings of truncated pages need to be
- // invalidated, even if InteropModeShared is in effect.
- d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above
+ if failureMask&linux.STATX_SIZE == 0 {
+ // d.size should be kept up to date, and privatized
+ // copy-on-write mappings of truncated pages need to be
+ // invalidated, even if InteropModeShared is in effect.
+ d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above
+ } else {
+ d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ }
}
}
if d.fs.opts.interop == InteropModeShared {
@@ -1208,13 +1489,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
return nil
}
}
- if stat.Mask&linux.STATX_MODE != 0 {
+ if stat.Mask&linux.STATX_MODE != 0 && failureMask&linux.STATX_MODE == 0 {
atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
}
- if stat.Mask&linux.STATX_UID != 0 {
+ if stat.Mask&linux.STATX_UID != 0 && failureMask&linux.STATX_UID == 0 {
atomic.StoreUint32(&d.uid, stat.UID)
}
- if stat.Mask&linux.STATX_GID != 0 {
+ if stat.Mask&linux.STATX_GID != 0 && failureMask&linux.STATX_GID == 0 {
atomic.StoreUint32(&d.gid, stat.GID)
}
// Note that stat.Atime.Nsec and stat.Mtime.Nsec can't be UTIME_NOW because
@@ -1222,15 +1503,19 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
// stat.Mtime to client-local timestamps above, and if
// !d.cachedMetadataAuthoritative() then we returned after calling
// d.file.setAttr(). For the same reason, now must have been initialized.
- if stat.Mask&linux.STATX_ATIME != 0 {
+ if stat.Mask&linux.STATX_ATIME != 0 && failureMask&linux.STATX_ATIME == 0 {
atomic.StoreInt64(&d.atime, stat.Atime.ToNsec())
atomic.StoreUint32(&d.atimeDirty, 0)
}
- if stat.Mask&linux.STATX_MTIME != 0 {
+ if stat.Mask&linux.STATX_MTIME != 0 && failureMask&linux.STATX_MTIME == 0 {
atomic.StoreInt64(&d.mtime, stat.Mtime.ToNsec())
atomic.StoreUint32(&d.mtimeDirty, 0)
}
atomic.StoreInt64(&d.ctime, now)
+ if failureMask != 0 {
+ // Setting some attribute failed on the remote filesystem.
+ return failureErr
+ }
return nil
}
@@ -1310,7 +1595,10 @@ func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats
// (b/148380782). Allow all other extended attributes to be passed through
// to the remote filesystem. This is inconsistent with Linux's 9p client,
// but consistent with other filesystems (e.g. FUSE).
- if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) {
+ //
+ // NOTE(b/202533394): Also disallow "trusted" namespace for now. This is
+ // consistent with the VFS1 gofer client.
+ if strings.HasPrefix(name, linux.XATTR_SECURITY_PREFIX) || strings.HasPrefix(name, linux.XATTR_SYSTEM_PREFIX) || strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) {
return linuxerr.EOPNOTSUPP
}
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
@@ -1346,6 +1634,20 @@ func dentryGIDFromP9GID(gid p9.GID) uint32 {
return uint32(gid)
}
+func dentryUIDFromLisaUID(uid lisafs.UID) uint32 {
+ if !uid.Ok() {
+ return uint32(auth.OverflowUID)
+ }
+ return uint32(uid)
+}
+
+func dentryGIDFromLisaGID(gid lisafs.GID) uint32 {
+ if !gid.Ok() {
+ return uint32(auth.OverflowGID)
+ }
+ return uint32(gid)
+}
+
// IncRef implements vfs.DentryImpl.IncRef.
func (d *dentry) IncRef() {
// d.refs may be 0 if d.fs.renameMu is locked, which serializes against
@@ -1654,15 +1956,24 @@ func (d *dentry) destroyLocked(ctx context.Context) {
d.dirty.RemoveAll()
}
d.dataMu.Unlock()
- // Clunk open fids and close open host FDs.
- if !d.readFile.isNil() {
- _ = d.readFile.close(ctx)
- }
- if !d.writeFile.isNil() && d.readFile != d.writeFile {
- _ = d.writeFile.close(ctx)
+ if d.fs.opts.lisaEnabled {
+ if d.readFDLisa.Ok() && d.readFDLisa.ID() != d.writeFDLisa.ID() {
+ d.readFDLisa.CloseBatched(ctx)
+ }
+ if d.writeFDLisa.Ok() {
+ d.writeFDLisa.CloseBatched(ctx)
+ }
+ } else {
+ // Clunk open fids and close open host FDs.
+ if !d.readFile.isNil() {
+ _ = d.readFile.close(ctx)
+ }
+ if !d.writeFile.isNil() && d.readFile != d.writeFile {
+ _ = d.writeFile.close(ctx)
+ }
+ d.readFile = p9file{}
+ d.writeFile = p9file{}
}
- d.readFile = p9file{}
- d.writeFile = p9file{}
if d.readFD >= 0 {
_ = unix.Close(int(d.readFD))
}
@@ -1674,7 +1985,7 @@ func (d *dentry) destroyLocked(ctx context.Context) {
d.mmapFD = -1
d.handleMu.Unlock()
- if !d.file.isNil() {
+ if d.isControlFileOk() {
// Note that it's possible that d.atimeDirty or d.mtimeDirty are true,
// i.e. client and server timestamps may differ (because e.g. a client
// write was serviced by the page cache, and only written back to the
@@ -1683,10 +1994,16 @@ func (d *dentry) destroyLocked(ctx context.Context) {
// instantiated for the same file would remain coherent. Unfortunately,
// this turns out to be too expensive in many cases, so for now we
// don't do this.
- if err := d.file.close(ctx); err != nil {
- log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err)
+
+ // Close the control FD.
+ if d.fs.opts.lisaEnabled {
+ d.controlFDLisa.CloseBatched(ctx)
+ } else {
+ if err := d.file.close(ctx); err != nil {
+ log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err)
+ }
+ d.file = p9file{}
}
- d.file = p9file{}
// Remove d from the set of syncable dentries.
d.fs.syncMu.Lock()
@@ -1712,10 +2029,29 @@ func (d *dentry) setDeleted() {
atomic.StoreUint32(&d.deleted, 1)
}
-func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) {
- if d.file.isNil() {
+func (d *dentry) isControlFileOk() bool {
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.Ok()
+ }
+ return !d.file.isNil()
+}
+
+func (d *dentry) isReadFileOk() bool {
+ if d.fs.opts.lisaEnabled {
+ return d.readFDLisa.Ok()
+ }
+ return !d.readFile.isNil()
+}
+
+func (d *dentry) listXattr(ctx context.Context, size uint64) ([]string, error) {
+ if !d.isControlFileOk() {
return nil, nil
}
+
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.ListXattr(ctx, size)
+ }
+
xattrMap, err := d.file.listXattr(ctx, size)
if err != nil {
return nil, err
@@ -1728,32 +2064,41 @@ func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size ui
}
func (d *dentry) getXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) {
- if d.file.isNil() {
+ if !d.isControlFileOk() {
return "", linuxerr.ENODATA
}
if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil {
return "", err
}
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.GetXattr(ctx, opts.Name, opts.Size)
+ }
return d.file.getXattr(ctx, opts.Name, opts.Size)
}
func (d *dentry) setXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetXattrOptions) error {
- if d.file.isNil() {
+ if !d.isControlFileOk() {
return linuxerr.EPERM
}
if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil {
return err
}
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.SetXattr(ctx, opts.Name, opts.Value, opts.Flags)
+ }
return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags)
}
func (d *dentry) removeXattr(ctx context.Context, creds *auth.Credentials, name string) error {
- if d.file.isNil() {
+ if !d.isControlFileOk() {
return linuxerr.EPERM
}
if err := d.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil {
return err
}
+ if d.fs.opts.lisaEnabled {
+ return d.controlFDLisa.RemoveXattr(ctx, name)
+ }
return d.file.removeXattr(ctx, name)
}
@@ -1765,19 +2110,30 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// O_TRUNC).
if !trunc {
d.handleMu.RLock()
- if (!read || !d.readFile.isNil()) && (!write || !d.writeFile.isNil()) {
+ var canReuseCurHandle bool
+ if d.fs.opts.lisaEnabled {
+ canReuseCurHandle = (!read || d.readFDLisa.Ok()) && (!write || d.writeFDLisa.Ok())
+ } else {
+ canReuseCurHandle = (!read || !d.readFile.isNil()) && (!write || !d.writeFile.isNil())
+ }
+ d.handleMu.RUnlock()
+ if canReuseCurHandle {
// Current handles are sufficient.
- d.handleMu.RUnlock()
return nil
}
- d.handleMu.RUnlock()
}
var fdsToCloseArr [2]int32
fdsToClose := fdsToCloseArr[:0]
invalidateTranslations := false
d.handleMu.Lock()
- if (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc {
+ var needNewHandle bool
+ if d.fs.opts.lisaEnabled {
+ needNewHandle = (read && !d.readFDLisa.Ok()) || (write && !d.writeFDLisa.Ok()) || trunc
+ } else {
+ needNewHandle = (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc
+ }
+ if needNewHandle {
// Get a new handle. If this file has been opened for both reading and
// writing, try to get a single handle that is usable for both:
//
@@ -1786,9 +2142,21 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
//
// - NOTE(b/141991141): Some filesystems may not ensure coherence
// between multiple handles for the same file.
- openReadable := !d.readFile.isNil() || read
- openWritable := !d.writeFile.isNil() || write
- h, err := openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ var (
+ openReadable bool
+ openWritable bool
+ h handle
+ err error
+ )
+ if d.fs.opts.lisaEnabled {
+ openReadable = d.readFDLisa.Ok() || read
+ openWritable = d.writeFDLisa.Ok() || write
+ h, err = openHandleLisa(ctx, d.controlFDLisa, openReadable, openWritable, trunc)
+ } else {
+ openReadable = !d.readFile.isNil() || read
+ openWritable = !d.writeFile.isNil() || write
+ h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ }
if linuxerr.Equals(linuxerr.EACCES, err) && (openReadable != read || openWritable != write) {
// It may not be possible to use a single handle for both
// reading and writing, since permissions on the file may have
@@ -1798,7 +2166,11 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
ctx.Debugf("gofer.dentry.ensureSharedHandle: bifurcating read/write handles for dentry %p", d)
openReadable = read
openWritable = write
- h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ if d.fs.opts.lisaEnabled {
+ h, err = openHandleLisa(ctx, d.controlFDLisa, openReadable, openWritable, trunc)
+ } else {
+ h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ }
}
if err != nil {
d.handleMu.Unlock()
@@ -1860,9 +2232,16 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// previously opened for reading (without an FD), then existing
// translations of the file may use the internal page cache;
// invalidate those mappings.
- if d.writeFile.isNil() {
- invalidateTranslations = !d.readFile.isNil()
- atomic.StoreInt32(&d.mmapFD, h.fd)
+ if d.fs.opts.lisaEnabled {
+ if !d.writeFDLisa.Ok() {
+ invalidateTranslations = d.readFDLisa.Ok()
+ atomic.StoreInt32(&d.mmapFD, h.fd)
+ }
+ } else {
+ if d.writeFile.isNil() {
+ invalidateTranslations = !d.readFile.isNil()
+ atomic.StoreInt32(&d.mmapFD, h.fd)
+ }
}
} else if openWritable && d.writeFD < 0 {
atomic.StoreInt32(&d.writeFD, h.fd)
@@ -1889,24 +2268,45 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
atomic.StoreInt32(&d.mmapFD, -1)
}
- // Switch to new fids.
- var oldReadFile p9file
- if openReadable {
- oldReadFile = d.readFile
- d.readFile = h.file
- }
- var oldWriteFile p9file
- if openWritable {
- oldWriteFile = d.writeFile
- d.writeFile = h.file
- }
- // NOTE(b/141991141): Clunk old fids before making new fids visible (by
- // unlocking d.handleMu).
- if !oldReadFile.isNil() {
- oldReadFile.close(ctx)
- }
- if !oldWriteFile.isNil() && oldReadFile != oldWriteFile {
- oldWriteFile.close(ctx)
+ // Switch to new fids/FDs.
+ if d.fs.opts.lisaEnabled {
+ oldReadFD := lisafs.InvalidFDID
+ if openReadable {
+ oldReadFD = d.readFDLisa.ID()
+ d.readFDLisa = h.fdLisa
+ }
+ oldWriteFD := lisafs.InvalidFDID
+ if openWritable {
+ oldWriteFD = d.writeFDLisa.ID()
+ d.writeFDLisa = h.fdLisa
+ }
+ // NOTE(b/141991141): Close old FDs before making new fids visible (by
+ // unlocking d.handleMu).
+ if oldReadFD.Ok() {
+ d.fs.clientLisa.CloseFDBatched(ctx, oldReadFD)
+ }
+ if oldWriteFD.Ok() && oldReadFD != oldWriteFD {
+ d.fs.clientLisa.CloseFDBatched(ctx, oldWriteFD)
+ }
+ } else {
+ var oldReadFile p9file
+ if openReadable {
+ oldReadFile = d.readFile
+ d.readFile = h.file
+ }
+ var oldWriteFile p9file
+ if openWritable {
+ oldWriteFile = d.writeFile
+ d.writeFile = h.file
+ }
+ // NOTE(b/141991141): Clunk old fids before making new fids visible (by
+ // unlocking d.handleMu).
+ if !oldReadFile.isNil() {
+ oldReadFile.close(ctx)
+ }
+ if !oldWriteFile.isNil() && oldReadFile != oldWriteFile {
+ oldWriteFile.close(ctx)
+ }
}
}
d.handleMu.Unlock()
@@ -1930,27 +2330,29 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
// Preconditions: d.handleMu must be locked.
func (d *dentry) readHandleLocked() handle {
return handle{
- file: d.readFile,
- fd: d.readFD,
+ fdLisa: d.readFDLisa,
+ file: d.readFile,
+ fd: d.readFD,
}
}
// Preconditions: d.handleMu must be locked.
func (d *dentry) writeHandleLocked() handle {
return handle{
- file: d.writeFile,
- fd: d.writeFD,
+ fdLisa: d.writeFDLisa,
+ file: d.writeFile,
+ fd: d.writeFD,
}
}
func (d *dentry) syncRemoteFile(ctx context.Context) error {
d.handleMu.RLock()
defer d.handleMu.RUnlock()
- return d.syncRemoteFileLocked(ctx)
+ return d.syncRemoteFileLocked(ctx, nil /* accFsyncFDIDsLisa */)
}
// Preconditions: d.handleMu must be locked.
-func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
+func (d *dentry) syncRemoteFileLocked(ctx context.Context, accFsyncFDIDsLisa *[]lisafs.FDID) error {
// If we have a host FD, fsyncing it is likely to be faster than an fsync
// RPC. Prefer syncing write handles over read handles, since some remote
// filesystem implementations may not sync changes made through write
@@ -1961,7 +2363,13 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
ctx.UninterruptibleSleepFinish(false)
return err
}
- if !d.writeFile.isNil() {
+ if d.fs.opts.lisaEnabled && d.writeFDLisa.Ok() {
+ if accFsyncFDIDsLisa != nil {
+ *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, d.writeFDLisa.ID())
+ return nil
+ }
+ return d.writeFDLisa.Sync(ctx)
+ } else if !d.fs.opts.lisaEnabled && !d.writeFile.isNil() {
return d.writeFile.fsync(ctx)
}
if d.readFD >= 0 {
@@ -1970,13 +2378,19 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
ctx.UninterruptibleSleepFinish(false)
return err
}
- if !d.readFile.isNil() {
+ if d.fs.opts.lisaEnabled && d.readFDLisa.Ok() {
+ if accFsyncFDIDsLisa != nil {
+ *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, d.readFDLisa.ID())
+ return nil
+ }
+ return d.readFDLisa.Sync(ctx)
+ } else if !d.fs.opts.lisaEnabled && !d.readFile.isNil() {
return d.readFile.fsync(ctx)
}
return nil
}
-func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) error {
+func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool, accFsyncFDIDsLisa *[]lisafs.FDID) error {
d.handleMu.RLock()
defer d.handleMu.RUnlock()
h := d.writeHandleLocked()
@@ -1989,7 +2403,7 @@ func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) err
return err
}
}
- if err := d.syncRemoteFileLocked(ctx); err != nil {
+ if err := d.syncRemoteFileLocked(ctx, accFsyncFDIDsLisa); err != nil {
if !forFilesystemSync {
return err
}
@@ -2046,18 +2460,33 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
d := fd.dentry()
const validMask = uint32(linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME)
if !d.cachedMetadataAuthoritative() && opts.Mask&validMask != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC {
- // Use specialFileFD.handle.file for the getattr if available, for the
- // same reason that we try to use open file handles in
- // dentry.updateFromGetattrLocked().
- var file p9file
- if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok {
- file = sffd.handle.file
- }
- d.metadataMu.Lock()
- err := d.updateFromGetattrLocked(ctx, file)
- d.metadataMu.Unlock()
- if err != nil {
- return linux.Statx{}, err
+ if d.fs.opts.lisaEnabled {
+ // Use specialFileFD.handle.fileLisa for the Stat if available, for the
+ // same reason that we try to use open FD in updateFromStatLisaLocked().
+ var fdLisa *lisafs.ClientFD
+ if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok {
+ fdLisa = &sffd.handle.fdLisa
+ }
+ d.metadataMu.Lock()
+ err := d.updateFromStatLisaLocked(ctx, fdLisa)
+ d.metadataMu.Unlock()
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ } else {
+ // Use specialFileFD.handle.file for the getattr if available, for the
+ // same reason that we try to use open file handles in
+ // dentry.updateFromGetattrLocked().
+ var file p9file
+ if sffd, ok := fd.vfsfd.Impl().(*specialFileFD); ok {
+ file = sffd.handle.file
+ }
+ d.metadataMu.Lock()
+ err := d.updateFromGetattrLocked(ctx, file)
+ d.metadataMu.Unlock()
+ if err != nil {
+ return linux.Statx{}, err
+ }
}
}
var stat linux.Statx
@@ -2078,7 +2507,7 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
// ListXattr implements vfs.FileDescriptionImpl.ListXattr.
func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) {
- return fd.dentry().listXattr(ctx, auth.CredentialsFromContext(ctx), size)
+ return fd.dentry().listXattr(ctx, size)
}
// GetXattr implements vfs.FileDescriptionImpl.GetXattr.
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
index 806392d50..d5cc73f33 100644
--- a/pkg/sentry/fsimpl/gofer/gofer_test.go
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -33,6 +33,7 @@ func TestDestroyIdempotent(t *testing.T) {
},
syncableDentries: make(map[*dentry]struct{}),
inoByQIDPath: make(map[uint64]uint64),
+ inoByKey: make(map[inoKey]uint64),
}
attr := &p9.Attr{
diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go
index 02540a754..394aecd62 100644
--- a/pkg/sentry/fsimpl/gofer/handle.go
+++ b/pkg/sentry/fsimpl/gofer/handle.go
@@ -17,6 +17,7 @@ package gofer
import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/hostfd"
@@ -26,10 +27,13 @@ import (
// handle represents a remote "open file descriptor", consisting of an opened
// fid (p9.File) and optionally a host file descriptor.
//
+// If lisafs is being used, fdLisa points to an open file on the server.
+//
// These are explicitly not savable.
type handle struct {
- file p9file
- fd int32 // -1 if unavailable
+ fdLisa lisafs.ClientFD
+ file p9file
+ fd int32 // -1 if unavailable
}
// Preconditions: read || write.
@@ -65,13 +69,47 @@ func openHandle(ctx context.Context, file p9file, read, write, trunc bool) (hand
}, nil
}
+// Preconditions: read || write.
+func openHandleLisa(ctx context.Context, fdLisa lisafs.ClientFD, read, write, trunc bool) (handle, error) {
+ var flags uint32
+ switch {
+ case read && write:
+ flags = unix.O_RDWR
+ case read:
+ flags = unix.O_RDONLY
+ case write:
+ flags = unix.O_WRONLY
+ default:
+ panic("tried to open unreadable and unwritable handle")
+ }
+ if trunc {
+ flags |= unix.O_TRUNC
+ }
+ openFD, hostFD, err := fdLisa.OpenAt(ctx, flags)
+ if err != nil {
+ return handle{fd: -1}, err
+ }
+ h := handle{
+ fdLisa: fdLisa.Client().NewFD(openFD),
+ fd: int32(hostFD),
+ }
+ return h, nil
+}
+
func (h *handle) isOpen() bool {
+ if h.fdLisa.Client() != nil {
+ return h.fdLisa.Ok()
+ }
return !h.file.isNil()
}
func (h *handle) close(ctx context.Context) {
- h.file.close(ctx)
- h.file = p9file{}
+ if h.fdLisa.Client() != nil {
+ h.fdLisa.CloseBatched(ctx)
+ } else {
+ h.file.close(ctx)
+ h.file = p9file{}
+ }
if h.fd >= 0 {
unix.Close(int(h.fd))
h.fd = -1
@@ -89,19 +127,27 @@ func (h *handle) readToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offs
return n, err
}
if dsts.NumBlocks() == 1 && !dsts.Head().NeedSafecopy() {
- n, err := h.file.readAt(ctx, dsts.Head().ToSlice(), offset)
- return uint64(n), err
+ if h.fdLisa.Client() != nil {
+ return h.fdLisa.Read(ctx, dsts.Head().ToSlice(), offset)
+ }
+ return h.file.readAt(ctx, dsts.Head().ToSlice(), offset)
}
// Buffer the read since p9.File.ReadAt() takes []byte.
buf := make([]byte, dsts.NumBytes())
- n, err := h.file.readAt(ctx, buf, offset)
+ var n uint64
+ var err error
+ if h.fdLisa.Client() != nil {
+ n, err = h.fdLisa.Read(ctx, buf, offset)
+ } else {
+ n, err = h.file.readAt(ctx, buf, offset)
+ }
if n == 0 {
return 0, err
}
if cp, cperr := safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:n]))); cperr != nil {
return cp, cperr
}
- return uint64(n), err
+ return n, err
}
func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
@@ -115,8 +161,10 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o
return n, err
}
if srcs.NumBlocks() == 1 && !srcs.Head().NeedSafecopy() {
- n, err := h.file.writeAt(ctx, srcs.Head().ToSlice(), offset)
- return uint64(n), err
+ if h.fdLisa.Client() != nil {
+ return h.fdLisa.Write(ctx, srcs.Head().ToSlice(), offset)
+ }
+ return h.file.writeAt(ctx, srcs.Head().ToSlice(), offset)
}
// Buffer the write since p9.File.WriteAt() takes []byte.
buf := make([]byte, srcs.NumBytes())
@@ -124,12 +172,18 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o
if cp == 0 {
return 0, cperr
}
- n, err := h.file.writeAt(ctx, buf[:cp], offset)
+ var n uint64
+ var err error
+ if h.fdLisa.Client() != nil {
+ n, err = h.fdLisa.Write(ctx, buf[:cp], offset)
+ } else {
+ n, err = h.file.writeAt(ctx, buf[:cp], offset)
+ }
// err takes precedence over cperr.
if err != nil {
- return uint64(n), err
+ return n, err
}
- return uint64(n), cperr
+ return n, cperr
}
type handleReadWriter struct {
diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go
index 5a3ddfc9d..0d97b60fd 100644
--- a/pkg/sentry/fsimpl/gofer/p9file.go
+++ b/pkg/sentry/fsimpl/gofer/p9file.go
@@ -141,18 +141,18 @@ func (f p9file) open(ctx context.Context, flags p9.OpenFlags) (*fd.FD, p9.QID, u
return fdobj, qid, iounit, err
}
-func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (uint64, error) {
ctx.UninterruptibleSleepStart(false)
n, err := f.file.ReadAt(p, offset)
ctx.UninterruptibleSleepFinish(false)
- return n, err
+ return uint64(n), err
}
-func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (uint64, error) {
ctx.UninterruptibleSleepStart(false)
n, err := f.file.WriteAt(p, offset)
ctx.UninterruptibleSleepFinish(false)
- return n, err
+ return uint64(n), err
}
func (f p9file) fsync(ctx context.Context) error {
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 947dbe05f..874f9873d 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -98,6 +98,12 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error {
}
d.handleMu.RLock()
defer d.handleMu.RUnlock()
+ if d.fs.opts.lisaEnabled {
+ if !d.writeFDLisa.Ok() {
+ return nil
+ }
+ return d.writeFDLisa.Flush(ctx)
+ }
if d.writeFile.isNil() {
return nil
}
@@ -110,6 +116,9 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint
return d.doAllocate(ctx, offset, length, func() error {
d.handleMu.RLock()
defer d.handleMu.RUnlock()
+ if d.fs.opts.lisaEnabled {
+ return d.writeFDLisa.Allocate(ctx, mode, offset, length)
+ }
return d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length)
})
}
@@ -282,8 +291,19 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
// changes to the host.
if newMode := vfs.ClearSUIDAndSGID(oldMode); newMode != oldMode {
atomic.StoreUint32(&d.mode, newMode)
- if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil {
- return 0, offset, err
+ if d.fs.opts.lisaEnabled {
+ stat := linux.Statx{Mask: linux.STATX_MODE, Mode: uint16(newMode)}
+ failureMask, failureErr, err := d.controlFDLisa.SetStat(ctx, &stat)
+ if err != nil {
+ return 0, offset, err
+ }
+ if failureMask != 0 {
+ return 0, offset, failureErr
+ }
+ } else {
+ if err := d.file.setAttr(ctx, p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(newMode)}); err != nil {
+ return 0, offset, err
+ }
}
}
}
@@ -677,7 +697,7 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6
// Sync implements vfs.FileDescriptionImpl.Sync.
func (fd *regularFileFD) Sync(ctx context.Context) error {
- return fd.dentry().syncCachedFile(ctx, false /* lowSyncExpectations */)
+ return fd.dentry().syncCachedFile(ctx, false /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */)
}
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
diff --git a/pkg/sentry/fsimpl/gofer/revalidate.go b/pkg/sentry/fsimpl/gofer/revalidate.go
index 226790a11..5d4009832 100644
--- a/pkg/sentry/fsimpl/gofer/revalidate.go
+++ b/pkg/sentry/fsimpl/gofer/revalidate.go
@@ -15,7 +15,9 @@
package gofer
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -234,28 +236,54 @@ func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualF
}
// Lock metadata on all dentries *before* getting attributes for them.
state.lockAllMetadata()
- stats, err := state.start.file.multiGetAttr(ctx, state.names)
- if err != nil {
- return err
+
+ var (
+ stats []p9.FullStat
+ statsLisa []linux.Statx
+ numStats int
+ )
+ if fs.opts.lisaEnabled {
+ var err error
+ statsLisa, err = state.start.controlFDLisa.WalkStat(ctx, state.names)
+ if err != nil {
+ return err
+ }
+ numStats = len(statsLisa)
+ } else {
+ var err error
+ stats, err = state.start.file.multiGetAttr(ctx, state.names)
+ if err != nil {
+ return err
+ }
+ numStats = len(stats)
}
i := -1
for d := state.popFront(); d != nil; d = state.popFront() {
i++
- found := i < len(stats)
+ found := i < numStats
if i == 0 && len(state.names[0]) == 0 {
if found && !d.isSynthetic() {
// First dentry is where the search is starting, just update attributes
// since it cannot be replaced.
- d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: acquired by lockAllMetadata.
+ if fs.opts.lisaEnabled {
+ d.updateFromLisaStatLocked(&statsLisa[i]) // +checklocksforce: acquired by lockAllMetadata.
+ } else {
+ d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: acquired by lockAllMetadata.
+ }
}
d.metadataMu.Unlock() // +checklocksforce: see above.
continue
}
- // Note that synthetic dentries will always fails the comparison check
- // below.
- if !found || d.qidPath != stats[i].QID.Path {
+ // Note that synthetic dentries will always fail this comparison check.
+ var shouldInvalidate bool
+ if fs.opts.lisaEnabled {
+ shouldInvalidate = !found || d.inoKey != inoKeyFromStat(&statsLisa[i])
+ } else {
+ shouldInvalidate = !found || d.qidPath != stats[i].QID.Path
+ }
+ if shouldInvalidate {
d.metadataMu.Unlock() // +checklocksforce: see above.
if !found && d.isSynthetic() {
// We have a synthetic file, and no remote file has arisen to replace
@@ -298,7 +326,11 @@ func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualF
}
// The file at this path hasn't changed. Just update cached metadata.
- d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: see above.
+ if fs.opts.lisaEnabled {
+ d.updateFromLisaStatLocked(&statsLisa[i]) // +checklocksforce: see above.
+ } else {
+ d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: see above.
+ }
d.metadataMu.Unlock()
}
diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go
index 8dcbc61ed..82878c056 100644
--- a/pkg/sentry/fsimpl/gofer/save_restore.go
+++ b/pkg/sentry/fsimpl/gofer/save_restore.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/safemem"
@@ -112,10 +113,19 @@ func (d *dentry) prepareSaveRecursive(ctx context.Context) error {
return err
}
}
- if !d.readFile.isNil() || !d.writeFile.isNil() {
- d.fs.savedDentryRW[d] = savedDentryRW{
- read: !d.readFile.isNil(),
- write: !d.writeFile.isNil(),
+ if d.fs.opts.lisaEnabled {
+ if d.readFDLisa.Ok() || d.writeFDLisa.Ok() {
+ d.fs.savedDentryRW[d] = savedDentryRW{
+ read: d.readFDLisa.Ok(),
+ write: d.writeFDLisa.Ok(),
+ }
+ }
+ } else {
+ if !d.readFile.isNil() || !d.writeFile.isNil() {
+ d.fs.savedDentryRW[d] = savedDentryRW{
+ read: !d.readFile.isNil(),
+ write: !d.writeFile.isNil(),
+ }
}
}
d.dirMu.Lock()
@@ -177,25 +187,37 @@ func (fs *filesystem) CompleteRestore(ctx context.Context, opts vfs.CompleteRest
return fmt.Errorf("no server FD available for filesystem with unique ID %q", fs.iopts.UniqueID)
}
fs.opts.fd = fd
- if err := fs.dial(ctx); err != nil {
- return err
- }
fs.inoByQIDPath = make(map[uint64]uint64)
+ fs.inoByKey = make(map[inoKey]uint64)
- // Restore the filesystem root.
- ctx.UninterruptibleSleepStart(false)
- attached, err := fs.client.Attach(fs.opts.aname)
- ctx.UninterruptibleSleepFinish(false)
- if err != nil {
- return err
- }
- attachFile := p9file{attached}
- qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
- if err != nil {
- return err
- }
- if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil {
- return err
+ if fs.opts.lisaEnabled {
+ rootInode, err := fs.initClientLisa(ctx)
+ if err != nil {
+ return err
+ }
+ if err := fs.root.restoreFileLisa(ctx, rootInode, &opts); err != nil {
+ return err
+ }
+ } else {
+ if err := fs.dial(ctx); err != nil {
+ return err
+ }
+
+ // Restore the filesystem root.
+ ctx.UninterruptibleSleepStart(false)
+ attached, err := fs.client.Attach(fs.opts.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ return err
+ }
+ attachFile := p9file{attached}
+ qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
+ if err != nil {
+ return err
+ }
+ if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil {
+ return err
+ }
}
// Restore remaining dentries.
@@ -255,18 +277,18 @@ func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrM
if d.isRegularFile() {
if opts.ValidateFileSizes {
if !attrMask.Size {
- return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d))
+ return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d))}
}
if d.size != attr.Size {
- return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, attr.Size)
+ return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, attr.Size)}
}
}
if opts.ValidateFileModificationTimestamps {
if !attrMask.MTime {
- return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d))
+ return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d))}
}
if want := dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds); d.mtime != want {
- return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want))
+ return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want))}
}
}
}
@@ -283,6 +305,55 @@ func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrM
return nil
}
+func (d *dentry) restoreFileLisa(ctx context.Context, inode *lisafs.Inode, opts *vfs.CompleteRestoreOptions) error {
+ d.controlFDLisa = d.fs.clientLisa.NewFD(inode.ControlFD)
+
+ // Gofers do not preserve inoKey across checkpoint/restore, so:
+ //
+ // - We must assume that the remote filesystem did not change in a way that
+ // would invalidate dentries, since we can't revalidate dentries by
+ // checking inoKey.
+ //
+ // - We need to associate the new inoKey with the existing d.ino.
+ d.inoKey = inoKeyFromStat(&inode.Stat)
+ d.fs.inoMu.Lock()
+ d.fs.inoByKey[d.inoKey] = d.ino
+ d.fs.inoMu.Unlock()
+
+ // Check metadata stability before updating metadata.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ if d.isRegularFile() {
+ if opts.ValidateFileSizes {
+ if inode.Stat.Mask&linux.STATX_SIZE != 0 {
+ return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d))}
+ }
+ if d.size != inode.Stat.Size {
+ return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, inode.Stat.Size)}
+ }
+ }
+ if opts.ValidateFileModificationTimestamps {
+ if inode.Stat.Mask&linux.STATX_MTIME != 0 {
+ return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d))}
+ }
+ if want := dentryTimestampFromLisa(inode.Stat.Mtime); d.mtime != want {
+ return vfs.ErrCorruption{fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want))}
+ }
+ }
+ }
+ if !d.cachedMetadataAuthoritative() {
+ d.updateFromLisaStatLocked(&inode.Stat)
+ }
+
+ if rw, ok := d.fs.savedDentryRW[d]; ok {
+ if err := d.ensureSharedHandle(ctx, rw.read, rw.write, false /* trunc */); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
// Preconditions: d is not synthetic.
func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error {
for _, child := range d.children {
@@ -305,19 +376,35 @@ func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.Comp
// only be detected by checking filesystem.syncableDentries). d.parent has been
// restored.
func (d *dentry) restoreRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error {
- qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name)
- if err != nil {
- return err
- }
- if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil {
- return err
+ if d.fs.opts.lisaEnabled {
+ inode, err := d.parent.controlFDLisa.Walk(ctx, d.name)
+ if err != nil {
+ return err
+ }
+ if err := d.restoreFileLisa(ctx, inode, opts); err != nil {
+ return err
+ }
+ } else {
+ qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name)
+ if err != nil {
+ return err
+ }
+ if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil {
+ return err
+ }
}
return d.restoreDescendantsRecursive(ctx, opts)
}
func (fd *specialFileFD) completeRestore(ctx context.Context) error {
d := fd.dentry()
- h, err := openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */)
+ var h handle
+ var err error
+ if d.fs.opts.lisaEnabled {
+ h, err = openHandleLisa(ctx, d.controlFDLisa, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */)
+ } else {
+ h, err = openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */)
+ }
if err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go
index fe15f8583..86ab70453 100644
--- a/pkg/sentry/fsimpl/gofer/socket.go
+++ b/pkg/sentry/fsimpl/gofer/socket.go
@@ -59,11 +59,6 @@ func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) {
// BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect.
func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error {
- cf, ok := sockTypeToP9(ce.Type())
- if !ok {
- return syserr.ErrConnectionRefused
- }
-
// No lock ordering required as only the ConnectingEndpoint has a mutex.
ce.Lock()
@@ -77,7 +72,7 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec
return syserr.ErrInvalidEndpointState
}
- c, err := e.newConnectedEndpoint(ctx, cf, ce.WaiterQueue())
+ c, err := e.newConnectedEndpoint(ctx, ce.Type(), ce.WaiterQueue())
if err != nil {
ce.Unlock()
return err
@@ -95,7 +90,7 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec
// UnidirectionalConnect implements
// transport.BoundEndpoint.UnidirectionalConnect.
func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.ConnectedEndpoint, *syserr.Error) {
- c, err := e.newConnectedEndpoint(ctx, p9.DgramSocket, &waiter.Queue{})
+ c, err := e.newConnectedEndpoint(ctx, linux.SOCK_DGRAM, &waiter.Queue{})
if err != nil {
return nil, err
}
@@ -111,25 +106,39 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect
return c, nil
}
-func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) {
- hostFile, err := e.dentry.file.connect(ctx, flags)
- if err != nil {
+func (e *endpoint) newConnectedEndpoint(ctx context.Context, sockType linux.SockType, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) {
+ if e.dentry.fs.opts.lisaEnabled {
+ hostSockFD, err := e.dentry.controlFDLisa.Connect(ctx, sockType)
+ if err != nil {
+ return nil, syserr.ErrConnectionRefused
+ }
+
+ c, serr := host.NewSCMEndpoint(ctx, hostSockFD, queue, e.path)
+ if serr != nil {
+ unix.Close(hostSockFD)
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v sockType %d: %v", e.dentry.file, sockType, serr)
+ return nil, serr
+ }
+ return c, nil
+ }
+
+ flags, ok := sockTypeToP9(sockType)
+ if !ok {
return nil, syserr.ErrConnectionRefused
}
- // Dup the fd so that the new endpoint can manage its lifetime.
- hostFD, err := unix.Dup(hostFile.FD())
+ hostFile, err := e.dentry.file.connect(ctx, flags)
if err != nil {
- log.Warningf("Could not dup host socket fd %d: %v", hostFile.FD(), err)
- return nil, syserr.FromError(err)
+ return nil, syserr.ErrConnectionRefused
}
- // After duplicating, we no longer need hostFile.
- hostFile.Close()
- c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path)
+ c, serr := host.NewSCMEndpoint(ctx, hostFile.FD(), queue, e.path)
if serr != nil {
- log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.dentry.file, flags, serr)
+ hostFile.Close()
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v sockType %d: %v", e.dentry.file, sockType, serr)
return nil, serr
}
+ // Ownership has been transferred to c.
+ hostFile.Release()
return c, nil
}
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index a8d47b65b..c568bbfd2 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/lisafs"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
@@ -149,6 +150,9 @@ func (fd *specialFileFD) OnClose(ctx context.Context) error {
if !fd.vfsfd.IsWritable() {
return nil
}
+ if fs := fd.filesystem(); fs.opts.lisaEnabled {
+ return fd.handle.fdLisa.Flush(ctx)
+ }
return fd.handle.file.flush(ctx)
}
@@ -184,6 +188,9 @@ func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint
if fd.isRegularFile {
d := fd.dentry()
return d.doAllocate(ctx, offset, length, func() error {
+ if d.fs.opts.lisaEnabled {
+ return fd.handle.fdLisa.Allocate(ctx, mode, offset, length)
+ }
return fd.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length)
})
}
@@ -371,10 +378,10 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) (
// Sync implements vfs.FileDescriptionImpl.Sync.
func (fd *specialFileFD) Sync(ctx context.Context) error {
- return fd.sync(ctx, false /* forFilesystemSync */)
+ return fd.sync(ctx, false /* forFilesystemSync */, nil /* accFsyncFDIDsLisa */)
}
-func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error {
+func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool, accFsyncFDIDsLisa *[]lisafs.FDID) error {
// Locks to ensure it didn't race with fd.Release().
fd.releaseMu.RLock()
defer fd.releaseMu.RUnlock()
@@ -391,6 +398,13 @@ func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error
ctx.UninterruptibleSleepFinish(false)
return err
}
+ if fs := fd.filesystem(); fs.opts.lisaEnabled {
+ if accFsyncFDIDsLisa != nil {
+ *accFsyncFDIDsLisa = append(*accFsyncFDIDsLisa, fd.handle.fdLisa.ID())
+ return nil
+ }
+ return fd.handle.fdLisa.Sync(ctx)
+ }
return fd.handle.file.fsync(ctx)
}()
if err != nil {
diff --git a/pkg/sentry/fsimpl/gofer/symlink.go b/pkg/sentry/fsimpl/gofer/symlink.go
index dbd834c67..27d9be5c4 100644
--- a/pkg/sentry/fsimpl/gofer/symlink.go
+++ b/pkg/sentry/fsimpl/gofer/symlink.go
@@ -35,7 +35,13 @@ func (d *dentry) readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
return target, nil
}
}
- target, err := d.file.readlink(ctx)
+ var target string
+ var err error
+ if d.fs.opts.lisaEnabled {
+ target, err = d.controlFDLisa.ReadLinkAt(ctx)
+ } else {
+ target, err = d.file.readlink(ctx)
+ }
if d.fs.opts.interop != InteropModeShared {
if err == nil {
d.haveTarget = true
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
index 9cbe805b9..07940b225 100644
--- a/pkg/sentry/fsimpl/gofer/time.go
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -17,6 +17,7 @@ package gofer
import (
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
@@ -24,6 +25,10 @@ func dentryTimestampFromP9(s, ns uint64) int64 {
return int64(s*1e9 + ns)
}
+func dentryTimestampFromLisa(t linux.StatxTimestamp) int64 {
+ return t.Sec*1e9 + int64(t.Nsec)
+}
+
// Preconditions: d.cachedMetadataAuthoritative() == true.
func (d *dentry) touchAtime(mnt *vfs.Mount) {
if mnt.Flags.NoATime || mnt.ReadOnly() {
diff --git a/pkg/sentry/fsimpl/mqfs/BUILD b/pkg/sentry/fsimpl/mqfs/BUILD
index e1a38686b..332c9b504 100644
--- a/pkg/sentry/fsimpl/mqfs/BUILD
+++ b/pkg/sentry/fsimpl/mqfs/BUILD
@@ -18,9 +18,9 @@ go_library(
name = "mqfs",
srcs = [
"mqfs.go",
- "root.go",
"queue.go",
"registry.go",
+ "root.go",
"root_inode_refs.go",
],
visibility = ["//pkg/sentry:internal"],
diff --git a/pkg/sentry/fsimpl/mqfs/mqfs.go b/pkg/sentry/fsimpl/mqfs/mqfs.go
index ed559cd13..c2b53c9d0 100644
--- a/pkg/sentry/fsimpl/mqfs/mqfs.go
+++ b/pkg/sentry/fsimpl/mqfs/mqfs.go
@@ -30,6 +30,7 @@ import (
)
const (
+ // Name is the user-visible filesystem name.
Name = "mqueue"
defaultMaxCachedDentries = uint64(1000)
)
@@ -73,7 +74,7 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
}
impl.fs.MaxCachedDentries = maxCachedDentries
- impl.root.IncRef()
+ impl.fs.VFSFilesystem().IncRef()
return impl.fs.VFSFilesystem(), impl.root.VFSDentry(), nil
}
@@ -109,7 +110,6 @@ type filesystem struct {
func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
fs.Filesystem.Release(ctx)
- fs.root.DecRef(ctx)
}
// MountOptions implements vfs.FilesystemImpl.MountOptions.
diff --git a/pkg/sentry/fsimpl/mqfs/registry.go b/pkg/sentry/fsimpl/mqfs/registry.go
index 2c9c79f01..c8fbe4d33 100644
--- a/pkg/sentry/fsimpl/mqfs/registry.go
+++ b/pkg/sentry/fsimpl/mqfs/registry.go
@@ -63,11 +63,12 @@ func NewRegistryImpl(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *
root: &dentry,
}
fs.VFSFilesystem().Init(vfsObj, &FilesystemType{}, fs)
+ vfsfs := fs.VFSFilesystem()
dentry.InitRoot(&fs.Filesystem, fs.newRootInode(ctx, creds))
- dentry.IncRef()
+ defer vfsfs.DecRef(ctx) // NewDisconnectedMount will obtain a ref on success.
- mount, err := vfsObj.NewDisconnectedMount(fs.VFSFilesystem(), dentry.VFSDentry(), &vfs.MountOptions{})
+ mount, err := vfsObj.NewDisconnectedMount(vfsfs, dentry.VFSDentry(), &vfs.MountOptions{})
if err != nil {
return nil, err
}
@@ -129,6 +130,7 @@ func (r *RegistryImpl) Unlink(ctx context.Context, name string) error {
// Destroy implements mq.RegistryImpl.Destroy.
func (r *RegistryImpl) Destroy(ctx context.Context) {
r.root.DecRef(ctx)
+ r.mount.DecRef(ctx)
}
// lookup retreives a kernfs.Inode using a name.
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 3b3dcf836..044902241 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -86,7 +86,7 @@ func putDentrySlice(ds *[]*dentry) {
// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() {
// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this.
//
-// +checklocksrelease:fs.renameMu
+// +checklocksreleaseread:fs.renameMu
func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, dsp **[]*dentry) {
fs.renameMu.RUnlock()
if *dsp == nil {
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index 26d44744b..7b0be9c14 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -268,6 +268,6 @@ func cpuInfoData(k *kernel.Kernel) string {
return buf.String()
}
-func shmData(v uint64) dynamicInode {
+func ipcData(v uint64) dynamicInode {
return newStaticFile(strconv.FormatUint(v, 10))
}
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index 4d3a2f7e6..faec36d8d 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -262,9 +262,8 @@ var _ dynamicInode = (*meminfoData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (*meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- k := kernel.KernelFromContext(ctx)
- mf := k.MemoryFile()
- mf.UpdateUsage()
+ mf := kernel.KernelFromContext(ctx).MemoryFile()
+ _ = mf.UpdateUsage() // Best effort
snapshot, totalUsage := usage.MemoryAccounting.Copy()
totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
anon := snapshot.Anonymous + snapshot.Tmpfs
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 99f64a9d8..82e2857b3 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -47,9 +47,12 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *
"kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"hostname": fs.newInode(ctx, root, 0444, &hostnameData{}),
"sem": fs.newInode(ctx, root, 0444, newStaticFile(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))),
- "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)),
- "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)),
- "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)),
+ "shmall": fs.newInode(ctx, root, 0444, ipcData(linux.SHMALL)),
+ "shmmax": fs.newInode(ctx, root, 0444, ipcData(linux.SHMMAX)),
+ "shmmni": fs.newInode(ctx, root, 0444, ipcData(linux.SHMMNI)),
+ "msgmni": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMNI)),
+ "msgmax": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMAX)),
+ "msgmnb": fs.newInode(ctx, root, 0444, ipcData(linux.MSGMNB)),
"yama": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"ptrace_scope": fs.newYAMAPtraceScopeFile(ctx, k, root),
}),
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index f322d2747..7fcb2d26b 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -84,6 +84,18 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fs.MaxCachedDentries = maxCachedDentries
fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+ k := kernel.KernelFromContext(ctx)
+ fsDirChildren := make(map[string]kernfs.Inode)
+ // Create an empty directory to serve as the mount point for cgroupfs when
+ // cgroups are available. This emulates Linux behaviour, see
+ // kernel/cgroup.c:cgroup_init(). Note that in Linux, userspace (typically
+ // the init process) is ultimately responsible for actually mounting
+ // cgroupfs, but the kernel creates the mountpoint. For the sentry, the
+ // launcher mounts cgroupfs.
+ if k.CgroupRegistry() != nil {
+ fsDirChildren["cgroup"] = fs.newDir(ctx, creds, defaultSysDirMode, nil)
+ }
+
root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
"block": fs.newDir(ctx, creds, defaultSysDirMode, nil),
"bus": fs.newDir(ctx, creds, defaultSysDirMode, nil),
@@ -97,7 +109,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}),
}),
"firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil),
- "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "fs": fs.newDir(ctx, creds, defaultSysDirMode, fsDirChildren),
"kernel": kernelDir(ctx, fs, creds),
"module": fs.newDir(ctx, creds, defaultSysDirMode, nil),
"power": fs.newDir(ctx, creds, defaultSysDirMode, nil),
diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go
index 0a0d914cc..0c46a3a13 100644
--- a/pkg/sentry/fsimpl/sys/sys_test.go
+++ b/pkg/sentry/fsimpl/sys/sys_test.go
@@ -87,3 +87,17 @@ func TestSysRootContainsExpectedEntries(t *testing.T) {
"power": linux.DT_DIR,
})
}
+
+func TestCgroupMountpointExists(t *testing.T) {
+ // Note: The mountpoint is only created if cgroups are available. This is
+ // the VFS2 implementation of sysfs and the test runs with VFS2 enabled, so
+ // we expect to see the mount point unconditionally.
+ s := newTestSystem(t)
+ defer s.Destroy()
+ pop := s.PathOpAtRoot("/fs")
+ s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{
+ "cgroup": linux.DT_DIR,
+ })
+ pop = s.PathOpAtRoot("/fs/cgroup")
+ s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{ /*empty*/ })
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index 0f2ac6144..453e1aa61 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -95,7 +95,7 @@ type regularFile struct {
func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, parentDir *directory) *inode {
file := &regularFile{
memFile: fs.mfp.MemoryFile(),
- memoryUsageKind: usage.Tmpfs,
+ memoryUsageKind: fs.usage,
seals: linux.F_SEAL_SEAL,
}
file.inode.init(file, fs, kuid, kgid, linux.S_IFREG|mode, parentDir)
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index feafb06e4..f84165aba 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -41,6 +41,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sentry/vfs/memxattr"
"gvisor.dev/gvisor/pkg/sync"
@@ -74,6 +75,10 @@ type filesystem struct {
// filesystem. Immutable.
mopts string
+ // usage is the memory accounting category under which pages backing
+ // files in this filesystem are accounted.
+ usage usage.MemoryKind
+
// mu serializes changes to the Dentry tree.
mu sync.RWMutex `state:"nosave"`
@@ -106,6 +111,10 @@ type FilesystemOpts struct {
// tmpfs filesystem. This allows tmpfs to "impersonate" other
// filesystems, like ramdiskfs and cgroupfs.
FilesystemType vfs.FilesystemType
+
+ // Usage is the memory accounting category under which pages backing files in
+ // the filesystem are accounted.
+ Usage *usage.MemoryKind
}
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
@@ -184,11 +193,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
clock := time.RealtimeClockFromContext(ctx)
+ memUsage := usage.Tmpfs
+ if tmpfsOpts.Usage != nil {
+ memUsage = *tmpfsOpts.Usage
+ }
fs := filesystem{
mfp: mfp,
clock: clock,
devMinor: devMinor,
mopts: opts.Data,
+ usage: memUsage,
}
fs.vfsfs.Init(vfsObj, newFSType, &fs)
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 52d47994d..8b059aa7d 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -74,7 +74,7 @@ func putDentrySlice(ds *[]*dentry) {
// but dentry slices are allocated lazily, and it's much easier to say "defer
// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() {
// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this.
-// +checklocksrelease:fs.renameMu
+// +checklocksreleaseread:fs.renameMu
func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) {
fs.renameMu.RUnlock()
if *ds == nil {
diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD
index 66fa1ad40..03c8e2f38 100644
--- a/pkg/sentry/hostmm/BUILD
+++ b/pkg/sentry/hostmm/BUILD
@@ -12,8 +12,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/fd",
- "//pkg/hostarch",
+ "//pkg/eventfd",
"//pkg/log",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/sentry/hostmm/hostmm.go b/pkg/sentry/hostmm/hostmm.go
index 285ea9050..5df06a60f 100644
--- a/pkg/sentry/hostmm/hostmm.go
+++ b/pkg/sentry/hostmm/hostmm.go
@@ -21,9 +21,7 @@ import (
"os"
"path"
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/fd"
- "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/log"
)
@@ -54,7 +52,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error)
}
defer eventControlFile.Close()
- eventFD, err := newEventFD()
+ eventFD, err := eventfd.Create()
if err != nil {
return nil, err
}
@@ -75,20 +73,11 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error)
const stopVal = 1 << 63
stopCh := make(chan struct{})
go func() { // S/R-SAFE: f provides synchronization if necessary
- rw := fd.NewReadWriter(eventFD.FD())
- var buf [sizeofUint64]byte
for {
- n, err := rw.Read(buf[:])
+ val, err := eventFD.Read()
if err != nil {
- if err == unix.EINTR {
- continue
- }
panic(fmt.Sprintf("failed to read from memory pressure level eventfd: %v", err))
}
- if n != sizeofUint64 {
- panic(fmt.Sprintf("short read from memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64))
- }
- val := hostarch.ByteOrder.Uint64(buf[:])
if val >= stopVal {
// Assume this was due to the notifier's "destructor" (the
// function returned by NotifyCurrentMemcgPressureCallback
@@ -101,30 +90,7 @@ func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error)
}
}()
return func() {
- rw := fd.NewReadWriter(eventFD.FD())
- var buf [sizeofUint64]byte
- hostarch.ByteOrder.PutUint64(buf[:], stopVal)
- for {
- n, err := rw.Write(buf[:])
- if err != nil {
- if err == unix.EINTR {
- continue
- }
- panic(fmt.Sprintf("failed to write to memory pressure level eventfd: %v", err))
- }
- if n != sizeofUint64 {
- panic(fmt.Sprintf("short write to memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64))
- }
- break
- }
+ eventFD.Write(stopVal)
<-stopCh
}, nil
}
-
-func newEventFD() (*fd.FD, error) {
- f, _, e := unix.Syscall(unix.SYS_EVENTFD2, 0, 0, 0)
- if e != 0 {
- return nil, fmt.Errorf("failed to create eventfd: %v", e)
- }
- return fd.New(int(f)), nil
-}
diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD
index 5bba9de0b..2363cec5f 100644
--- a/pkg/sentry/inet/BUILD
+++ b/pkg/sentry/inet/BUILD
@@ -1,13 +1,26 @@
load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(
default_visibility = ["//:sandbox"],
licenses = ["notice"],
)
+go_template_instance(
+ name = "atomicptr_netns",
+ out = "atomicptr_netns_unsafe.go",
+ package = "inet",
+ prefix = "Namespace",
+ template = "//pkg/sync/atomicptr:generic_atomicptr",
+ types = {
+ "Value": "Namespace",
+ },
+)
+
go_library(
name = "inet",
srcs = [
+ "atomicptr_netns_unsafe.go",
"context.go",
"inet.go",
"namespace.go",
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 9f30a7706..f3f16eb7a 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -216,7 +216,6 @@ go_library(
visibility = ["//:sandbox"],
deps = [
":uncaught_signal_go_proto",
- "//pkg/sentry/kernel/ipc",
"//pkg/abi",
"//pkg/abi/linux",
"//pkg/abi/linux/errno",
@@ -257,8 +256,8 @@ go_library(
"//pkg/sentry/hostcpu",
"//pkg/sentry/inet",
"//pkg/sentry/kernel/auth",
- "//pkg/sentry/kernel/epoll",
"//pkg/sentry/kernel/futex",
+ "//pkg/sentry/kernel/ipc",
"//pkg/sentry/kernel/mq",
"//pkg/sentry/kernel/msgqueue",
"//pkg/sentry/kernel/sched",
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 6006c46a9..8d0a21baf 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -66,7 +66,7 @@ type pollEntry struct {
file *refs.WeakRef `state:"manual"`
id FileIdentifier `state:"wait"`
userData [2]int32
- waiter waiter.Entry `state:"manual"`
+ waiter waiter.Entry
mask waiter.EventMask
flags EntryFlags
@@ -102,7 +102,7 @@ type EventPoll struct {
// Wait queue is used to notify interested parties when the event poll
// object itself becomes readable or writable.
- waiter.Queue `state:"zerovalue"`
+ waiter.Queue
// files is the map of all the files currently being observed, it is
// protected by mu.
@@ -454,14 +454,3 @@ func (e *EventPoll) RemoveEntry(ctx context.Context, id FileIdentifier) error {
return nil
}
-
-// UnregisterEpollWaiters removes the epoll waiter objects from the waiting
-// queues. This is different from Release() as the file is not dereferenced.
-func (e *EventPoll) UnregisterEpollWaiters() {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- for _, entry := range e.files {
- entry.id.File.EventUnregister(&entry.waiter)
- }
-}
diff --git a/pkg/sentry/kernel/epoll/epoll_state.go b/pkg/sentry/kernel/epoll/epoll_state.go
index e08d6287f..135a6d72c 100644
--- a/pkg/sentry/kernel/epoll/epoll_state.go
+++ b/pkg/sentry/kernel/epoll/epoll_state.go
@@ -21,9 +21,7 @@ import (
// afterLoad is invoked by stateify.
func (p *pollEntry) afterLoad() {
- p.waiter.Callback = p
p.file = refs.NewWeakRef(p.id.File, p)
- p.id.File.EventRegister(&p.waiter, p.mask)
}
// afterLoad is invoked by stateify.
diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go
index 5ea44a2c2..bf625dede 100644
--- a/pkg/sentry/kernel/eventfd/eventfd.go
+++ b/pkg/sentry/kernel/eventfd/eventfd.go
@@ -54,7 +54,7 @@ type EventOperations struct {
// Queue is used to notify interested parties when the event object
// becomes readable or writable.
- wq waiter.Queue `state:"zerovalue"`
+ wq waiter.Queue
// val is the current value of the event counter.
val uint64
diff --git a/pkg/sentry/kernel/ipc/BUILD b/pkg/sentry/kernel/ipc/BUILD
index a5cbb2b51..bb5cf1c17 100644
--- a/pkg/sentry/kernel/ipc/BUILD
+++ b/pkg/sentry/kernel/ipc/BUILD
@@ -5,9 +5,9 @@ package(licenses = ["notice"])
go_library(
name = "ipc",
srcs = [
+ "ns.go",
"object.go",
"registry.go",
- "ns.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 04b24369a..d4851ccda 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -57,7 +57,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/kernel/epoll"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
@@ -79,11 +78,19 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-// VFS2Enabled is set to true when VFS2 is enabled. Added as a global for allow
-// easy access everywhere. To be removed once VFS2 becomes the default.
+// VFS2Enabled is set to true when VFS2 is enabled. Added as a global to allow
+// easy access everywhere.
+//
+// TODO(gvisor.dev/issue/1624): Remove when VFS1 is no longer used.
var VFS2Enabled = false
-// FUSEEnabled is set to true when FUSE is enabled. Added as a global for allow
+// LISAFSEnabled is set to true when lisafs protocol is enabled. Added as a
+// global to allow easy access everywhere.
+//
+// TODO(gvisor.dev/issue/6319): Remove when lisafs is default.
+var LISAFSEnabled = false
+
+// FUSEEnabled is set to true when FUSE is enabled. Added as a global to allow
// easy access everywhere. To be removed once FUSE is completed.
var FUSEEnabled = false
@@ -484,11 +491,6 @@ func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error {
return err
}
- // Remove all epoll waiter objects from underlying wait queues.
- // NOTE: for programs to resume execution in future snapshot scenarios,
- // we will need to re-establish these waiter objects after saving.
- k.tasks.unregisterEpollWaiters(ctx)
-
// Clear the dirent cache before saving because Dirents must be Loaded in a
// particular order (parents before children), and Loading dirents from a cache
// breaks that order.
@@ -621,32 +623,6 @@ func (k *Kernel) flushWritesToFiles(ctx context.Context) error {
})
}
-// Preconditions: !VFS2Enabled.
-func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) {
- ts.mu.RLock()
- defer ts.mu.RUnlock()
-
- // Tasks that belong to the same process could potentially point to the
- // same FDTable. So we retain a map of processed ones to avoid
- // processing the same FDTable multiple times.
- processed := make(map[*FDTable]struct{})
- for t := range ts.Root.tids {
- // We can skip locking Task.mu here since the kernel is paused.
- if t.fdTable == nil {
- continue
- }
- if _, ok := processed[t.fdTable]; ok {
- continue
- }
- t.fdTable.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
- if e, ok := file.FileOperations.(*epoll.EventPoll); ok {
- e.UnregisterEpollWaiters()
- }
- })
- processed[t.fdTable] = struct{}{}
- }
-}
-
// Preconditions: The kernel must be paused.
func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
invalidated := make(map[*mm.MemoryManager]struct{})
diff --git a/pkg/sentry/kernel/mq/mq.go b/pkg/sentry/kernel/mq/mq.go
index a7c787081..50ca6d34a 100644
--- a/pkg/sentry/kernel/mq/mq.go
+++ b/pkg/sentry/kernel/mq/mq.go
@@ -40,8 +40,10 @@ const (
ReadWrite
)
+// MaxName is the maximum size for a queue name.
+const MaxName = 255
+
const (
- MaxName = 255 // Maximum size for a queue name.
maxPriority = linux.MQ_PRIO_MAX - 1 // Highest possible message priority.
maxQueuesDefault = linux.DFLT_QUEUESMAX // Default max number of queues.
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 86beee6fe..8345473f3 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -55,7 +55,7 @@ const (
//
// +stateify savable
type Pipe struct {
- waiter.Queue `state:"nosave"`
+ waiter.Queue
// isNamed indicates whether this is a named pipe.
//
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index 9a95bf44c..1ea3c1bf7 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -158,7 +158,7 @@ type Task struct {
// signalQueue is protected by the signalMutex. Note that the task does
// not implement all queue methods, specifically the readiness checks.
// The task only broadcast a notification on signal delivery.
- signalQueue waiter.Queue `state:"zerovalue"`
+ signalQueue waiter.Queue
// If groupStopPending is true, the task should participate in a group
// stop in the interrupt path.
@@ -511,9 +511,7 @@ type Task struct {
numaNodeMask uint64
// netns is the task's network namespace. netns is never nil.
- //
- // netns is protected by mu.
- netns *inet.Namespace
+ netns inet.NamespaceAtomicPtr
// If rseqPreempted is true, before the next call to p.Switch(),
// interrupt rseq critical regions as defined by rseqAddr and
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index e174913d1..69a3227f0 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -447,7 +447,7 @@ func (t *Task) Unshare(flags int32) error {
t.mu.Unlock()
return linuxerr.EPERM
}
- t.netns = inet.NewNamespace(t.netns)
+ t.netns.Store(inet.NewNamespace(t.netns.Load()))
}
if flags&linux.CLONE_NEWUTS != 0 {
if !haveCapSysAdmin {
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
index 8de08151a..f0c168ecc 100644
--- a/pkg/sentry/kernel/task_log.go
+++ b/pkg/sentry/kernel/task_log.go
@@ -191,9 +191,11 @@ const (
//
// Preconditions: The task's owning TaskSet.mu must be locked.
func (t *Task) updateInfoLocked() {
- // Use the task's TID in the root PID namespace for logging.
+ // Use the task's TID and PID in the root PID namespace for logging.
+ pid := t.tg.pidns.owner.Root.tgids[t.tg]
tid := t.tg.pidns.owner.Root.tids[t]
- t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid))
+ t.logPrefix.Store(fmt.Sprintf("[% 4d:% 4d] ", pid, tid))
+
t.rebuildTraceContext(tid)
}
@@ -249,5 +251,9 @@ func (t *Task) traceExecEvent(image *TaskImage) {
return
}
defer file.DecRef(t)
- trace.Logf(t.traceContext, traceCategory, "exec: %s", file.PathnameWithDeleted(t))
+
+ // traceExecEvent function may be called before the task goroutine
+ // starts, so we must use the async context.
+ name := file.PathnameWithDeleted(t.AsyncContext())
+ trace.Logf(t.traceContext, traceCategory, "exec: %s", name)
}
diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go
index f7711232c..e31e2b2e8 100644
--- a/pkg/sentry/kernel/task_net.go
+++ b/pkg/sentry/kernel/task_net.go
@@ -20,9 +20,7 @@ import (
// IsNetworkNamespaced returns true if t is in a non-root network namespace.
func (t *Task) IsNetworkNamespaced() bool {
- t.mu.Lock()
- defer t.mu.Unlock()
- return !t.netns.IsRoot()
+ return !t.netns.Load().IsRoot()
}
// NetworkContext returns the network stack used by the task. NetworkContext
@@ -31,14 +29,10 @@ func (t *Task) IsNetworkNamespaced() bool {
// TODO(gvisor.dev/issue/1833): Migrate callers of this method to
// NetworkNamespace().
func (t *Task) NetworkContext() inet.Stack {
- t.mu.Lock()
- defer t.mu.Unlock()
- return t.netns.Stack()
+ return t.netns.Load().Stack()
}
// NetworkNamespace returns the network namespace observed by the task.
func (t *Task) NetworkNamespace() *inet.Namespace {
- t.mu.Lock()
- defer t.mu.Unlock()
- return t.netns
+ return t.netns.Load()
}
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index 217c6f531..4919dea7c 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -140,7 +140,6 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
allowedCPUMask: cfg.AllowedCPUMask.Copy(),
ioUsage: &usage.IO{},
niceness: cfg.Niceness,
- netns: cfg.NetworkNamespace,
utsns: cfg.UTSNamespace,
ipcns: cfg.IPCNamespace,
abstractSockets: cfg.AbstractSocketNamespace,
@@ -152,6 +151,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
containerID: cfg.ContainerID,
cgroups: make(map[Cgroup]struct{}),
}
+ t.netns.Store(cfg.NetworkNamespace)
t.creds.Store(cfg.Credentials)
t.endStopCond.L = &t.tg.signalHandlers.mu
t.ptraceTracer.Store((*Task)(nil))
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 77ad62445..e38b723ce 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -324,11 +324,7 @@ type threadGroupNode struct {
// eventQueue is notified whenever a event of interest to Task.Wait occurs
// in a child of this thread group, or a ptrace tracee of a task in this
// thread group. Events are defined in task_exit.go.
- //
- // Note that we cannot check and save this wait queue similarly to other
- // wait queues, as the queue will not be empty by the time of saving, due
- // to the wait sourced from Exec().
- eventQueue waiter.Queue `state:"nosave"`
+ eventQueue waiter.Queue
// leader is the thread group's leader, which is the oldest task in the
// thread group; usually the last task in the thread group to call
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index 9e00c2cec..dc12ad357 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -89,7 +89,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (hostar
}
// Offset + length must not overflow.
if end := opts.Offset + opts.Length; end < opts.Offset {
- return 0, linuxerr.ENOMEM
+ return 0, linuxerr.EOVERFLOW
}
} else {
opts.Offset = 0
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 8a490b3de..834d72408 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -1,13 +1,26 @@
load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
+go_template_instance(
+ name = "atomicptr_machine",
+ out = "atomicptr_machine_unsafe.go",
+ package = "kvm",
+ prefix = "machine",
+ template = "//pkg/sync/atomicptr:generic_atomicptr",
+ types = {
+ "Value": "machine",
+ },
+)
+
go_library(
name = "kvm",
srcs = [
"address_space.go",
"address_space_amd64.go",
"address_space_arm64.go",
+ "atomicptr_machine_unsafe.go",
"bluepill.go",
"bluepill_allocator.go",
"bluepill_amd64.go",
@@ -50,7 +63,6 @@ go_library(
"//pkg/procid",
"//pkg/ring0",
"//pkg/ring0/pagetables",
- "//pkg/safecopy",
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/arch/fpu",
@@ -58,6 +70,7 @@ go_library(
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
"//pkg/sentry/time",
+ "//pkg/sighandling",
"//pkg/sync",
"@org_golang_x_sys//unix:go_default_library",
],
@@ -69,10 +82,17 @@ go_test(
"kvm_amd64_test.go",
"kvm_amd64_test.s",
"kvm_arm64_test.go",
+ "kvm_safecopy_test.go",
"kvm_test.go",
"virtual_map_test.go",
],
library = ":kvm",
+ # FIXME(gvisor.dev/issue/3374): Not working with all build systems.
+ nogo = False,
+ # cgo has to be disabled. We have seen libc that blocks all signals and
+ # calls mmap from pthread_create, but we use SIGSYS to trap mmap system
+ # calls.
+ pure = True,
tags = [
"manual",
"nogotsan",
@@ -81,8 +101,10 @@ go_test(
deps = [
"//pkg/abi/linux",
"//pkg/hostarch",
+ "//pkg/memutil",
"//pkg/ring0",
"//pkg/ring0/pagetables",
+ "//pkg/safecopy",
"//pkg/sentry/arch",
"//pkg/sentry/arch/fpu",
"//pkg/sentry/platform",
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
index bb9967b9f..5be2215ed 100644
--- a/pkg/sentry/platform/kvm/bluepill.go
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -19,8 +19,8 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/ring0"
- "gvisor.dev/gvisor/pkg/safecopy"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sighandling"
)
// bluepill enters guest mode.
@@ -61,6 +61,9 @@ var (
// This is called by bluepillHandler.
savedHandler uintptr
+ // savedSigsysHandler is a pointer to the previous handler of the SIGSYS signal.
+ savedSigsysHandler uintptr
+
// dieTrampolineAddr is the address of dieTrampoline.
dieTrampolineAddr uintptr
)
@@ -94,7 +97,7 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) {
func init() {
// Install the handler.
- if err := safecopy.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil {
+ if err := sighandling.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil {
panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go
index 0567c8d32..b2db2bb9f 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.go
@@ -71,10 +71,6 @@ func (c *vCPU) KernelSyscall() {
if regs.Rax != ^uint64(0) {
regs.Rip -= 2 // Rewind.
}
- // We only trigger a bluepill entry in the bluepill function, and can
- // therefore be guaranteed that there is no floating point state to be
- // loaded on resuming from halt. We only worry about saving on exit.
- ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no.
// N.B. Since KernelSyscall is called when the kernel makes a syscall,
// FS_BASE is already set for correct execution of this function.
//
@@ -112,8 +108,6 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
regs.Rip = 0
}
// See above.
- ring0.SaveFloatingPoint(c.floatingPointState.BytePointer()) // escapes: no.
- // See above.
ring0.HaltAndWriteFSBase(regs) // escapes: no, reload host segment.
}
@@ -144,5 +138,5 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
// Set the context pointer to the saved floating point state. This is
// where the guest data has been serialized, the kernel will restore
// from this new pointer value.
- context.Fpstate = uint64(uintptrValue(c.floatingPointState.BytePointer()))
+ context.Fpstate = uint64(uintptrValue(c.FloatingPointState().BytePointer())) // escapes: no.
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s
index c2a1dca11..5d8358f64 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.s
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.s
@@ -32,6 +32,8 @@
// This is checked as the source of the fault.
#define CLI $0xfa
+#define SYS_MMAP 9
+
// See bluepill.go.
TEXT ·bluepill(SB),NOSPLIT,$0
begin:
@@ -95,6 +97,31 @@ TEXT ·addrOfSighandler(SB), $0-8
MOVQ AX, ret+0(FP)
RET
+TEXT ·sigsysHandler(SB),NOSPLIT,$0
+	// si_code should be SYS_SECCOMP (signal raised by a seccomp trap).
+ MOVQ $1, CX
+ CMPL CX, 0x8(SI)
+ JNE fallback
+
+ MOVL CONTEXT_RAX(DX), CX
+ CMPL CX, $SYS_MMAP
+ JNE fallback
+ PUSHQ DX // First argument (context).
+ CALL ·seccompMmapHandler(SB) // Call the handler.
+ POPQ DX // Discard the argument.
+ RET
+fallback:
+ // Jump to the previous signal handler.
+ XORQ CX, CX
+ MOVQ ·savedSigsysHandler(SB), AX
+ JMP AX
+
+// func addrOfSigsysHandler() uintptr
+TEXT ·addrOfSigsysHandler(SB), $0-8
+ MOVQ $·sigsysHandler(SB), AX
+ MOVQ AX, ret+0(FP)
+ RET
+
// dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation.
TEXT ·dieTrampoline(SB),NOSPLIT,$0
PUSHQ BX // First argument (vCPU).
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index acb0cb05f..df772d620 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -70,7 +70,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
lazyVfp := c.GetLazyVFP()
if lazyVfp != 0 {
- fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
+ fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no
context.Fpsimd64.Fpsr = fpsimd.Fpsr
context.Fpsimd64.Fpcr = fpsimd.Fpcr
context.Fpsimd64.Vregs = fpsimd.Vregs
@@ -90,12 +90,12 @@ func (c *vCPU) KernelSyscall() {
fpDisableTrap := ring0.CPACREL1()
if fpDisableTrap != 0 {
- fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
+ fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no
fpcr := ring0.GetFPCR()
fpsr := ring0.GetFPSR()
fpsimd.Fpcr = uint32(fpcr)
fpsimd.Fpsr = uint32(fpsr)
- ring0.SaveVRegs(c.floatingPointState.BytePointer())
+ ring0.SaveVRegs(c.FloatingPointState().BytePointer()) // escapes: no
}
ring0.Halt()
@@ -114,12 +114,12 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
fpDisableTrap := ring0.CPACREL1()
if fpDisableTrap != 0 {
- fpsimd := fpsimdPtr(c.floatingPointState.BytePointer())
+ fpsimd := fpsimdPtr(c.FloatingPointState().BytePointer()) // escapes: no
fpcr := ring0.GetFPCR()
fpsr := ring0.GetFPSR()
fpsimd.Fpcr = uint32(fpcr)
fpsimd.Fpsr = uint32(fpsr)
- ring0.SaveVRegs(c.floatingPointState.BytePointer())
+ ring0.SaveVRegs(c.FloatingPointState().BytePointer()) // escapes: no
}
ring0.Halt()
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s
index 308f2a951..9690e3772 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.s
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.s
@@ -29,9 +29,12 @@
// Only limited use of the context is done in the assembly stub below, most is
// done in the Go handlers.
#define SIGINFO_SIGNO 0x0
+#define SIGINFO_CODE 0x8
#define CONTEXT_PC 0x1B8
#define CONTEXT_R0 0xB8
+#define SYS_MMAP 222
+
// getTLS returns the value of TPIDR_EL0 register.
TEXT ·getTLS(SB),NOSPLIT,$0-8
MRS TPIDR_EL0, R1
@@ -98,6 +101,37 @@ TEXT ·addrOfSighandler(SB), $0-8
MOVD R0, ret+0(FP)
RET
+// The arguments are the following:
+//
+// R0 - The signal number.
+// R1 - Pointer to siginfo_t structure.
+// R2 - Pointer to ucontext structure.
+//
+TEXT ·sigsysHandler(SB),NOSPLIT,$0
+ // si_code should be SYS_SECCOMP.
+ MOVD SIGINFO_CODE(R1), R7
+ CMPW $1, R7
+ BNE fallback
+
+ CMPW $SYS_MMAP, R8
+ BNE fallback
+
+ MOVD R2, 8(RSP)
+ BL ·seccompMmapHandler(SB) // Call the handler.
+
+ RET
+
+fallback:
+ // Jump to the previous signal handler.
+ MOVD ·savedHandler(SB), R7
+ B (R7)
+
+// func addrOfSigsysHandler() uintptr
+TEXT ·addrOfSigsysHandler(SB), $0-8
+ MOVD $·sigsysHandler(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation.
TEXT ·dieTrampoline(SB),NOSPLIT,$0
// R0: Fake the old PC as caller
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 0f0c1e73b..e38ca05c0 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -193,36 +193,8 @@ func bluepillHandler(context unsafe.Pointer) {
return
}
- // Increment the fault count.
- atomic.AddUint32(&c.faults, 1)
-
- // For MMIO, the physical address is the first data item.
- physical = uintptr(c.runData.data[0])
- virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE)
- if !ok {
- c.die(bluepillArchContext(context), "invalid physical address")
- return
- }
-
- // We now need to fill in the data appropriately. KVM
- // expects us to provide the result of the given MMIO
- // operation in the runData struct. This is safe
- // because, if a fault occurs here, the same fault
- // would have occurred in guest mode. The kernel should
- // not create invalid page table mappings.
- data := (*[8]byte)(unsafe.Pointer(&c.runData.data[1]))
- length := (uintptr)((uint32)(c.runData.data[2]))
- write := (uint8)(((c.runData.data[2] >> 32) & 0xff)) != 0
- for i := uintptr(0); i < length; i++ {
- b := bytePtr(uintptr(virtual) + i)
- if write {
- // Write to the given address.
- *b = data[i]
- } else {
- // Read from the given address.
- data[i] = *b
- }
- }
+ c.die(bluepillArchContext(context), "exit_mmio")
+ return
case _KVM_EXIT_IRQ_WINDOW_OPEN:
bluepillStopGuest(c)
case _KVM_EXIT_SHUTDOWN:
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index aac0fdffe..ad6863646 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -77,7 +77,11 @@ var (
// OpenDevice opens the KVM device at /dev/kvm and returns the File.
func OpenDevice() (*os.File, error) {
- f, err := os.OpenFile("/dev/kvm", unix.O_RDWR, 0)
+ dev, ok := os.LookupEnv("GVISOR_KVM_DEV")
+ if !ok {
+ dev = "/dev/kvm"
+ }
+ f, err := os.OpenFile(dev, unix.O_RDWR, 0)
if err != nil {
return nil, fmt.Errorf("error opening /dev/kvm: %v", err)
}
diff --git a/pkg/sentry/platform/kvm/kvm_safecopy_test.go b/pkg/sentry/platform/kvm/kvm_safecopy_test.go
new file mode 100644
index 000000000..fe488e707
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_safecopy_test.go
@@ -0,0 +1,104 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// FIXME(gvisor.dev/issue/6629): These tests don't pass on ARM64.
+//
+//go:build amd64
+// +build amd64
+
+package kvm
+
+import (
+ "fmt"
+ "os"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/safecopy"
+)
+
+func testSafecopy(t *testing.T, mapSize uintptr, fileSize uintptr, testFunc func(t *testing.T, c *vCPU, addr uintptr)) {
+ memfd, err := memutil.CreateMemFD(fmt.Sprintf("kvm_test_%d", os.Getpid()), 0)
+ if err != nil {
+ t.Errorf("error creating memfd: %v", err)
+ }
+
+ memfile := os.NewFile(uintptr(memfd), "kvm_test")
+ memfile.Truncate(int64(fileSize))
+ kvmTest(t, nil, func(c *vCPU) bool {
+ const n = 10
+ mappings := make([]uintptr, n)
+ defer func() {
+ for i := 0; i < n && mappings[i] != 0; i++ {
+ unix.RawSyscall(
+ unix.SYS_MUNMAP,
+ mappings[i], mapSize, 0)
+ }
+ }()
+ for i := 0; i < n; i++ {
+ addr, _, errno := unix.RawSyscall6(
+ unix.SYS_MMAP,
+ 0,
+ mapSize,
+ unix.PROT_READ|unix.PROT_WRITE,
+ unix.MAP_SHARED|unix.MAP_FILE,
+ uintptr(memfile.Fd()),
+ 0)
+ if errno != 0 {
+ t.Errorf("error mapping file: %v", errno)
+ }
+ mappings[i] = addr
+ testFunc(t, c, addr)
+ }
+ return false
+ })
+}
+
+func TestSafecopySigbus(t *testing.T) {
+ mapSize := uintptr(faultBlockSize)
+ fileSize := mapSize - hostarch.PageSize
+ buf := make([]byte, hostarch.PageSize)
+ testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) {
+ want := safecopy.BusError{addr + fileSize}
+ bluepill(c)
+ _, err := safecopy.CopyIn(buf, unsafe.Pointer(addr+fileSize))
+ if err != want {
+ t.Errorf("expected error: got %v, want %v", err, want)
+ }
+ })
+}
+
+func TestSafecopy(t *testing.T) {
+ mapSize := uintptr(faultBlockSize)
+ fileSize := mapSize
+ testSafecopy(t, mapSize, fileSize, func(t *testing.T, c *vCPU, addr uintptr) {
+ want := uint32(0x12345678)
+ bluepill(c)
+ _, err := safecopy.SwapUint32(unsafe.Pointer(addr+fileSize-8), want)
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ bluepill(c)
+ val, err := safecopy.LoadUint32(unsafe.Pointer(addr + fileSize - 8))
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ if val != want {
+ t.Errorf("incorrect value: got %x, want %x", val, want)
+ }
+ })
+}
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index d67563958..f1f7e4ea4 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -17,16 +17,20 @@ package kvm
import (
"fmt"
"runtime"
+ gosync "sync"
"sync/atomic"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/seccomp"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sighandling"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -35,6 +39,9 @@ type machine struct {
// fd is the vm fd.
fd int
+ // machinePoolIndex is the index in the machinePool array.
+ machinePoolIndex uint32
+
// nextSlot is the next slot for setMemoryRegion.
//
// This must be accessed atomically. If nextSlot is ^uint32(0), then
@@ -192,6 +199,10 @@ func (m *machine) newVCPU() *vCPU {
return c // Done.
}
+// readOnlyGuestRegions contains regions that have to be mapped read-only into
+// the guest physical address space. Right now, it is used on arm64 only.
+var readOnlyGuestRegions []region
+
// newMachine returns a new VM context.
func newMachine(vm int) (*machine, error) {
// Create the machine.
@@ -227,6 +238,10 @@ func newMachine(vm int) (*machine, error) {
m.upperSharedPageTables.MarkReadOnlyShared()
m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress)
+ // Install seccomp rules to trap runtime mmap system calls. They will
+ // be handled by seccompMmapHandler.
+ seccompMmapRules(m)
+
// Apply the physical mappings. Note that these mappings may point to
// guest physical addresses that are not actually available. These
// physical pages are mapped on demand, see kernel_unsafe.go.
@@ -241,32 +256,11 @@ func newMachine(vm int) (*machine, error) {
return true // Keep iterating.
})
- var physicalRegionsReadOnly []physicalRegion
- var physicalRegionsAvailable []physicalRegion
-
- physicalRegionsReadOnly = rdonlyRegionsForSetMem()
- physicalRegionsAvailable = availableRegionsForSetMem()
-
- // Map all read-only regions.
- for _, r := range physicalRegionsReadOnly {
- m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY)
- }
-
// Ensure that the currently mapped virtual regions are actually
// available in the VM. Note that this doesn't guarantee no future
// faults, however it should guarantee that everything is available to
// ensure successful vCPU entry.
- applyVirtualRegions(func(vr virtualRegion) {
- if excludeVirtualRegion(vr) {
- return // skip region.
- }
-
- for _, r := range physicalRegionsReadOnly {
- if vr.virtual == r.virtual {
- return
- }
- }
-
+ mapRegion := func(vr region, flags uint32) {
for virtual := vr.virtual; virtual < vr.virtual+vr.length; {
physical, length, ok := translateToPhysical(virtual)
if !ok {
@@ -280,9 +274,32 @@ func newMachine(vm int) (*machine, error) {
}
// Ensure the physical range is mapped.
- m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE)
+ m.mapPhysical(physical, length, physicalRegions, flags)
virtual += length
}
+ }
+
+ for _, vr := range readOnlyGuestRegions {
+ mapRegion(vr, _KVM_MEM_READONLY)
+ }
+
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ return // skip region.
+ }
+ for _, r := range readOnlyGuestRegions {
+ if vr.virtual == r.virtual {
+ return
+ }
+ }
+ // Take into account that the stack can grow down.
+ if vr.filename == "[stack]" {
+ vr.virtual -= 1 << 20
+ vr.length += 1 << 20
+ }
+
+ mapRegion(vr.region, 0)
+
})
// Initialize architecture state.
@@ -352,6 +369,10 @@ func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalReg
func (m *machine) Destroy() {
runtime.SetFinalizer(m, nil)
+ machinePoolMu.Lock()
+ machinePool[m.machinePoolIndex].Store(nil)
+ machinePoolMu.Unlock()
+
// Destroy vCPUs.
for _, c := range m.vCPUsByID {
if c == nil {
@@ -683,3 +704,72 @@ func (c *vCPU) setSystemTimeLegacy() error {
}
}
}
+
+const machinePoolSize = 16
+
+// machinePool is enumerated from the seccompMmapHandler signal handler
+var (
+ machinePool [machinePoolSize]machineAtomicPtr
+ machinePoolLen uint32
+ machinePoolMu sync.Mutex
+ seccompMmapRulesOnce gosync.Once
+)
+
+func sigsysHandler()
+func addrOfSigsysHandler() uintptr
+
+// seccompMmapRules adds seccomp rules to trap mmap system calls that will be
+// handled in seccompMmapHandler.
+func seccompMmapRules(m *machine) {
+ seccompMmapRulesOnce.Do(func() {
+ // Install the handler.
+ if err := sighandling.ReplaceSignalHandler(unix.SIGSYS, addrOfSigsysHandler(), &savedSigsysHandler); err != nil {
+ panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
+ }
+ rules := []seccomp.RuleSet{}
+ rules = append(rules, []seccomp.RuleSet{
+ // Trap mmap system calls and handle them in sigsysGoHandler
+ {
+ Rules: seccomp.SyscallRules{
+ unix.SYS_MMAP: {
+ {
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ seccomp.MatchAny{},
+ /* MAP_DENYWRITE is ignored and used only for filtering. */
+ seccomp.MaskedEqual(unix.MAP_DENYWRITE, 0),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_TRAP,
+ },
+ }...)
+ instrs, err := seccomp.BuildProgram(rules, linux.SECCOMP_RET_ALLOW, linux.SECCOMP_RET_ALLOW)
+ if err != nil {
+ panic(fmt.Sprintf("failed to build rules: %v", err))
+ }
+ // Perform the actual installation.
+ if err := seccomp.SetFilter(instrs); err != nil {
+ panic(fmt.Sprintf("failed to set filter: %v", err))
+ }
+ })
+
+ machinePoolMu.Lock()
+ n := atomic.LoadUint32(&machinePoolLen)
+ i := uint32(0)
+ for ; i < n; i++ {
+ if machinePool[i].Load() == nil {
+ break
+ }
+ }
+ if i == n {
+ if i == machinePoolSize {
+ machinePoolMu.Unlock()
+ panic("machinePool is full")
+ }
+ atomic.AddUint32(&machinePoolLen, 1)
+ }
+ machinePool[i].Store(m)
+ m.machinePoolIndex = i
+ machinePoolMu.Unlock()
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index a96634381..5bc023899 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
)
@@ -72,10 +71,6 @@ type vCPUArchState struct {
//
// This starts above fixedKernelPCID.
PCIDs *pagetables.PCIDs
-
- // floatingPointState is the floating point state buffer used in guest
- // to host transitions. See usage in bluepill_amd64.go.
- floatingPointState fpu.State
}
const (
@@ -152,12 +147,6 @@ func (c *vCPU) initArchState() error {
return fmt.Errorf("error setting user registers: %v", errno)
}
- // Allocate some floating point state save area for the local vCPU.
- // This will be saved prior to leaving the guest, and we restore from
- // this always. We cannot use the pointer in the context alone because
- // we don't know how large the area there is in reality.
- c.floatingPointState = fpu.NewState()
-
// Set the time offset to the host native time.
return c.setSystemTime()
}
@@ -309,22 +298,6 @@ func loadByte(ptr *byte) byte {
return *ptr
}
-// prefaultFloatingPointState touches each page of the floating point state to
-// be sure that its physical pages are mapped.
-//
-// Otherwise the kernel can trigger KVM_EXIT_MMIO and an instruction that
-// triggered a fault will be emulated by the kvm kernel code, but it can't
-// emulate instructions like xsave and xrstor.
-//
-//go:nosplit
-func prefaultFloatingPointState(data *fpu.State) {
- size := len(*data)
- for i := 0; i < size; i += hostarch.PageSize {
- loadByte(&(*data)[i])
- }
- loadByte(&(*data)[size-1])
-}
-
// SwitchToUser unpacks architectural-details.
func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo) (hostarch.AccessType, error) {
// Check for canonical addresses.
@@ -355,11 +328,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo)
// allocations occur.
entersyscall()
bluepill(c)
- // The root table physical page has to be mapped to not fault in iret
- // or sysret after switching into a user address space. sysret and
- // iret are in the upper half that is global and already mapped.
- switchOpts.PageTables.PrefaultRootTable()
- prefaultFloatingPointState(switchOpts.FloatingPointState)
vector = c.CPU.SwitchToUser(switchOpts)
exitsyscall()
@@ -522,3 +490,7 @@ func (m *machine) getNewVCPU() *vCPU {
}
return nil
}
+
+func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion {
+ return physicalRegions
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
index de798bb2c..fbacea9ad 100644
--- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -161,3 +161,15 @@ func (c *vCPU) getSystemRegisters(sregs *systemRegs) unix.Errno {
}
return 0
}
+
+//go:nosplit
+func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) {
+ ctx := bluepillArchContext(context)
+
+ // MAP_DENYWRITE is deprecated and ignored by kernel. We use it only for seccomp filters.
+ addr, _, e := unix.RawSyscall6(uintptr(ctx.Rax), uintptr(ctx.Rdi), uintptr(ctx.Rsi),
+ uintptr(ctx.Rdx), uintptr(ctx.R10)|unix.MAP_DENYWRITE, uintptr(ctx.R8), uintptr(ctx.R9))
+ ctx.Rax = uint64(addr)
+
+ return addr, uintptr(ctx.Rsi), e
+}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 7937a8481..31998a600 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
)
@@ -40,10 +39,6 @@ type vCPUArchState struct {
//
// This starts above fixedKernelPCID.
PCIDs *pagetables.PCIDs
-
- // floatingPointState is the floating point state buffer used in guest
- // to host transitions. See usage in bluepill_arm64.go.
- floatingPointState fpu.State
}
const (
@@ -110,18 +105,128 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
return phyRegions
}
+// archPhysicalRegions fills readOnlyGuestRegions and allocates separate
+// physical regions for them.
+func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion {
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ return // skip region.
+ }
+ if !vr.accessType.Write {
+ readOnlyGuestRegions = append(readOnlyGuestRegions, vr.region)
+ }
+ })
+
+ rdRegions := readOnlyGuestRegions[:]
+
+ // Add an unreachable region.
+ rdRegions = append(rdRegions, region{
+ virtual: 0xffffffffffffffff,
+ length: 0,
+ })
+
+ var regions []physicalRegion
+ addValidRegion := func(r *physicalRegion, virtual, length uintptr) {
+ if length == 0 {
+ return
+ }
+ regions = append(regions, physicalRegion{
+ region: region{
+ virtual: virtual,
+ length: length,
+ },
+ physical: r.physical + (virtual - r.virtual),
+ })
+ }
+ i := 0
+ for _, pr := range physicalRegions {
+ start := pr.virtual
+ end := pr.virtual + pr.length
+ for start < end {
+ rdRegion := rdRegions[i]
+ rdStart := rdRegion.virtual
+ rdEnd := rdRegion.virtual + rdRegion.length
+ if rdEnd <= start {
+ i++
+ continue
+ }
+ if rdStart > start {
+ newEnd := rdStart
+ if end < rdStart {
+ newEnd = end
+ }
+ addValidRegion(&pr, start, newEnd-start)
+ start = rdStart
+ continue
+ }
+ if rdEnd < end {
+ addValidRegion(&pr, start, rdEnd-start)
+ start = rdEnd
+ continue
+ }
+ addValidRegion(&pr, start, end-start)
+ start = end
+ }
+ }
+
+ return regions
+}
+
// Get all available physicalRegions.
-func availableRegionsForSetMem() (phyRegions []physicalRegion) {
- var excludeRegions []region
+func availableRegionsForSetMem() []physicalRegion {
+ var excludedRegions []region
applyVirtualRegions(func(vr virtualRegion) {
if !vr.accessType.Write {
- excludeRegions = append(excludeRegions, vr.region)
+ excludedRegions = append(excludedRegions, vr.region)
}
})
- phyRegions = computePhysicalRegions(excludeRegions)
+ // Add an unreachable region.
+ excludedRegions = append(excludedRegions, region{
+ virtual: 0xffffffffffffffff,
+ length: 0,
+ })
- return phyRegions
+ var regions []physicalRegion
+ addValidRegion := func(r *physicalRegion, virtual, length uintptr) {
+ if length == 0 {
+ return
+ }
+ regions = append(regions, physicalRegion{
+ region: region{
+ virtual: virtual,
+ length: length,
+ },
+ physical: r.physical + (virtual - r.virtual),
+ })
+ }
+ i := 0
+ for _, pr := range physicalRegions {
+ start := pr.virtual
+ end := pr.virtual + pr.length
+ for start < end {
+ er := excludedRegions[i]
+ excludeEnd := er.virtual + er.length
+ excludeStart := er.virtual
+ if excludeEnd < start {
+ i++
+ continue
+ }
+ if excludeStart < start {
+ start = excludeEnd
+ i++
+ continue
+ }
+ rend := excludeStart
+ if rend > end {
+ rend = end
+ }
+ addValidRegion(&pr, start, rend-start)
+ start = excludeEnd
+ }
+ }
+
+ return regions
}
// nonCanonical generates a canonical address return.
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 1a4a9ce7d..e73d5c544 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -28,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/ring0"
"gvisor.dev/gvisor/pkg/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/arch/fpu"
"gvisor.dev/gvisor/pkg/sentry/platform"
ktime "gvisor.dev/gvisor/pkg/sentry/time"
)
@@ -159,8 +158,6 @@ func (c *vCPU) initArchState() error {
c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs)
}
- c.floatingPointState = fpu.NewState()
-
return c.setSystemTime()
}
@@ -333,3 +330,15 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *linux.SignalInfo)
}
}
+
+//go:nosplit
+func seccompMmapSyscall(context unsafe.Pointer) (uintptr, uintptr, unix.Errno) {
+ ctx := bluepillArchContext(context)
+
+ // MAP_DENYWRITE is deprecated and ignored by kernel. We use it only for seccomp filters.
+ addr, _, e := unix.RawSyscall6(uintptr(ctx.Regs[8]), uintptr(ctx.Regs[0]), uintptr(ctx.Regs[1]),
+ uintptr(ctx.Regs[2]), uintptr(ctx.Regs[3])|unix.MAP_DENYWRITE, uintptr(ctx.Regs[4]), uintptr(ctx.Regs[5]))
+ ctx.Regs[0] = uint64(addr)
+
+ return addr, uintptr(ctx.Regs[1]), e
+}
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index cc3a1253b..cf3a4e7c9 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -171,3 +171,46 @@ func (c *vCPU) setSignalMask() error {
return nil
}
+
+// seccompMmapHandler is a signal handler for runtime mmap system calls
+// that are trapped by seccomp.
+//
+// It executes the mmap syscall with specified arguments and maps a new region
+// to the guest.
+//
+//go:nosplit
+func seccompMmapHandler(context unsafe.Pointer) {
+ addr, length, errno := seccompMmapSyscall(context)
+ if errno != 0 {
+ return
+ }
+
+ for i := uint32(0); i < atomic.LoadUint32(&machinePoolLen); i++ {
+ m := machinePool[i].Load()
+ if m == nil {
+ continue
+ }
+
+ // Map the new region to the guest.
+ vr := region{
+ virtual: addr,
+ length: length,
+ }
+ for virtual := vr.virtual; virtual < vr.virtual+vr.length; {
+ physical, length, ok := translateToPhysical(virtual)
+ if !ok {
+ // This must be an invalid region that was
+ // knocked out by creation of the physical map.
+ return
+ }
+ if virtual+length > vr.virtual+vr.length {
+ // Cap the length to the end of the area.
+ length = vr.virtual + vr.length - virtual
+ }
+
+ // Ensure the physical range is mapped.
+ m.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE)
+ virtual += length
+ }
+ }
+}
diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go
index d812e6c26..9864d1258 100644
--- a/pkg/sentry/platform/kvm/physical_map.go
+++ b/pkg/sentry/platform/kvm/physical_map.go
@@ -168,6 +168,9 @@ func computePhysicalRegions(excludedRegions []region) (physicalRegions []physica
}
addValidRegion(lastExcludedEnd, ring0.MaximumUserAddress-lastExcludedEnd)
+ // Do arch-specific actions on physical regions.
+ physicalRegions = archPhysicalRegions(physicalRegions)
+
// Dump our all physical regions.
for _, r := range physicalRegions {
log.Infof("physicalRegion: virtual [%x,%x) => physical [%x,%x)",
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
index 6d0ba8252..346a10043 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
@@ -30,8 +30,8 @@ import (
func TLSWorks() bool
// SetTestTarget sets the rip appropriately.
-func SetTestTarget(regs *arch.Registers, fn func()) {
- regs.Pc = uint64(reflect.ValueOf(fn).Pointer())
+func SetTestTarget(regs *arch.Registers, fn uintptr) {
+ regs.Pc = uint64(fn)
}
// SetTouchTarget sets rax appropriately.
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
index 7348c29a5..42876245a 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
@@ -28,6 +28,11 @@ TEXT ·Getpid(SB),NOSPLIT,$0
SVC
RET
+TEXT ·AddrOfGetpid(SB),NOSPLIT,$0-8
+ MOVD $·Getpid(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·Touch(SB),NOSPLIT,$0
start:
MOVD 0(R8), R1
@@ -35,21 +40,41 @@ start:
SVC
B start
+TEXT ·AddrOfTouch(SB),NOSPLIT,$0-8
+ MOVD $·Touch(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·HaltLoop(SB),NOSPLIT,$0
start:
HLT
B start
+TEXT ·AddOfHaltLoop(SB),NOSPLIT,$0-8
+ MOVD $·HaltLoop(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
// This function simulates a loop of syscall.
TEXT ·SyscallLoop(SB),NOSPLIT,$0
start:
SVC
B start
+TEXT ·AddrOfSyscallLoop(SB),NOSPLIT,$0-8
+ MOVD $·SyscallLoop(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·SpinLoop(SB),NOSPLIT,$0
start:
B start
+TEXT ·AddrOfSpinLoop(SB),NOSPLIT,$0-8
+ MOVD $·SpinLoop(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·TLSWorks(SB),NOSPLIT,$0-8
NO_LOCAL_POINTERS
MOVD $0x6789, R5
@@ -125,6 +150,11 @@ TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0
SVC
RET // never reached
+TEXT ·AddrOfTwiddleRegsSyscall(SB),NOSPLIT,$0-8
+ MOVD $·TwiddleRegsSyscall(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
+
TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0
TWIDDLE_REGS()
MSR R10, TPIDR_EL0
@@ -132,3 +162,8 @@ TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0
// Branch to Register branches unconditionally to an address in <Rn>.
JMP (R6) // <=> br x6, must fault
RET // never reached
+
+TEXT ·AddrOfTwiddleRegsFault(SB),NOSPLIT,$0-8
+ MOVD $·TwiddleRegsFault(SB), R0
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/sentry/seccheck/BUILD b/pkg/sentry/seccheck/BUILD
index 943fa180d..35feb969f 100644
--- a/pkg/sentry/seccheck/BUILD
+++ b/pkg/sentry/seccheck/BUILD
@@ -8,6 +8,8 @@ go_fieldenum(
name = "seccheck_fieldenum",
srcs = [
"clone.go",
+ "execve.go",
+ "exit.go",
"task.go",
],
out = "seccheck_fieldenum.go",
@@ -29,6 +31,8 @@ go_library(
name = "seccheck",
srcs = [
"clone.go",
+ "execve.go",
+ "exit.go",
"seccheck.go",
"seccheck_fieldenum.go",
"seqatomic_checkerslice_unsafe.go",
diff --git a/pkg/sentry/seccheck/execve.go b/pkg/sentry/seccheck/execve.go
new file mode 100644
index 000000000..f36e0730e
--- /dev/null
+++ b/pkg/sentry/seccheck/execve.go
@@ -0,0 +1,65 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// ExecveInfo contains information used by the Execve checkpoint.
+//
+// +fieldenum Execve
+type ExecveInfo struct {
+ // Invoker identifies the invoking thread.
+ Invoker TaskInfo
+
+ // Credentials are the invoking thread's credentials.
+ Credentials *auth.Credentials
+
+ // BinaryPath is a path to the executable binary file being switched to in
+ // the mount namespace in which it was opened.
+ BinaryPath string
+
+ // Argv is the new process image's argument vector.
+ Argv []string
+
+ // Env is the new process image's environment variables.
+ Env []string
+
+ // BinaryMode is the executable binary file's mode.
+ BinaryMode uint16
+
+ // BinarySHA256 is the SHA-256 hash of the executable binary file.
+ //
+ // Note that this requires reading the entire file into memory, which is
+ // likely to be extremely slow.
+ BinarySHA256 [32]byte
+}
+
+// ExecveReq returns fields required by the Execve checkpoint.
+func (s *state) ExecveReq() ExecveFieldSet {
+ return s.execveReq.Load()
+}
+
+// Execve is called at the Execve checkpoint.
+func (s *state) Execve(ctx context.Context, mask ExecveFieldSet, info *ExecveInfo) error {
+ for _, c := range s.getCheckers() {
+ if err := c.Execve(ctx, mask, *info); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/seccheck/exit.go b/pkg/sentry/seccheck/exit.go
new file mode 100644
index 000000000..69cb6911c
--- /dev/null
+++ b/pkg/sentry/seccheck/exit.go
@@ -0,0 +1,57 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package seccheck
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// ExitNotifyParentInfo contains information used by the ExitNotifyParent
+// checkpoint.
+//
+// +fieldenum ExitNotifyParent
+type ExitNotifyParentInfo struct {
+ // Exiter identifies the exiting thread. Note that by the checkpoint's
+ // definition, Exiter.ThreadID == Exiter.ThreadGroupID and
+ // Exiter.ThreadStartTime == Exiter.ThreadGroupStartTime, so requesting
+ // ThreadGroup* fields is redundant.
+ Exiter TaskInfo
+
+ // ExitStatus is the exiting thread group's exit status, as reported
+ // by wait*().
+ ExitStatus linux.WaitStatus
+}
+
+// ExitNotifyParentReq returns fields required by the ExitNotifyParent
+// checkpoint.
+func (s *state) ExitNotifyParentReq() ExitNotifyParentFieldSet {
+ return s.exitNotifyParentReq.Load()
+}
+
+// ExitNotifyParent is called at the ExitNotifyParent checkpoint.
+//
+// The ExitNotifyParent checkpoint occurs when a zombied thread group leader,
+// not waiting for exit acknowledgement from a non-parent ptracer, becomes the
+// last non-dead thread in its thread group and notifies its parent of its
+// exiting.
+func (s *state) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info *ExitNotifyParentInfo) error {
+ for _, c := range s.getCheckers() {
+ if err := c.ExitNotifyParent(ctx, mask, *info); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/seccheck/seccheck.go b/pkg/sentry/seccheck/seccheck.go
index b6c9d44ce..e13274096 100644
--- a/pkg/sentry/seccheck/seccheck.go
+++ b/pkg/sentry/seccheck/seccheck.go
@@ -29,6 +29,8 @@ type Point uint
// PointX represents the checkpoint X.
const (
PointClone Point = iota
+ PointExecve
+ PointExitNotifyParent
// Add new Points above this line.
pointLength
@@ -47,6 +49,8 @@ const (
// registered concurrently with invocations of checkpoints).
type Checker interface {
Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error
+ Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error
+ ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error
}
// CheckerDefaults may be embedded by implementations of Checker to obtain
@@ -58,6 +62,16 @@ func (CheckerDefaults) Clone(ctx context.Context, mask CloneFieldSet, info Clone
return nil
}
+// Execve implements Checker.Execve.
+func (CheckerDefaults) Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error {
+ return nil
+}
+
+// ExitNotifyParent implements Checker.ExitNotifyParent.
+func (CheckerDefaults) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error {
+ return nil
+}
+
// CheckerReq indicates what checkpoints a corresponding Checker runs at, and
// what information it requires at those checkpoints.
type CheckerReq struct {
@@ -69,7 +83,9 @@ type CheckerReq struct {
// All of the following fields indicate what fields in the corresponding
// XInfo struct will be requested at the corresponding checkpoint.
- Clone CloneFields
+ Clone CloneFields
+ Execve ExecveFields
+ ExitNotifyParent ExitNotifyParentFields
}
// Global is the method receiver of all seccheck functions.
@@ -101,7 +117,9 @@ type state struct {
// corresponding XInfo struct have been requested by any registered
// checker, are accessed using atomic memory operations, and are mutated
// with registrationMu locked.
- cloneReq CloneFieldSet
+ cloneReq CloneFieldSet
+ execveReq ExecveFieldSet
+ exitNotifyParentReq ExitNotifyParentFieldSet
}
// AppendChecker registers the given Checker to execute at checkpoints. The
@@ -110,7 +128,11 @@ type state struct {
func (s *state) AppendChecker(c Checker, req *CheckerReq) {
s.registrationMu.Lock()
defer s.registrationMu.Unlock()
+
s.cloneReq.AddFieldsLoadable(req.Clone)
+ s.execveReq.AddFieldsLoadable(req.Execve)
+ s.exitNotifyParentReq.AddFieldsLoadable(req.ExitNotifyParent)
+
s.appendCheckerLocked(c)
for _, p := range req.Points {
word, bit := p/32, p%32
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 7ee89a735..00f925166 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -4,7 +4,10 @@ package(licenses = ["notice"])
go_library(
name = "socket",
- srcs = ["socket.go"],
+ srcs = [
+ "socket.go",
+ "socket_state.go",
+ ],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 00a5e729a..6077b2150 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -29,10 +29,9 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "time"
)
-const maxInt = int(^uint(0) >> 1)
-
// SCMCredentials represents a SCM_CREDENTIALS socket control message.
type SCMCredentials interface {
transport.CredentialsControlMessage
@@ -78,7 +77,7 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
}
// Files implements SCMRights.Files.
-func (fs *RightsFiles) Files(ctx context.Context, max int) (RightsFiles, bool) {
+func (fs *RightsFiles) Files(_ context.Context, max int) (RightsFiles, bool) {
n := max
var trunc bool
if l := len(*fs); n > l {
@@ -124,7 +123,7 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32
break
}
- fds = append(fds, int32(fd))
+ fds = append(fds, fd)
}
return fds, trunc
}
@@ -300,8 +299,8 @@ func alignSlice(buf []byte, align uint) []byte {
}
// PackTimestamp packs a SO_TIMESTAMP socket control message.
-func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte {
- timestampP := linux.NsecToTimeval(timestamp)
+func PackTimestamp(t *kernel.Task, timestamp time.Time, buf []byte) []byte {
+ timestampP := linux.NsecToTimeval(timestamp.UnixNano())
return putCmsgStruct(
buf,
linux.SOL_SOCKET,
@@ -355,6 +354,17 @@ func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketIn
)
}
+// PackIPv6PacketInfo packs an IPV6_PKTINFO socket control message.
+func PackIPv6PacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPv6PacketInfo, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IPV6,
+ linux.IPV6_PKTINFO,
+ t.Arch().Width(),
+ packetInfo,
+ )
+}
+
// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message.
func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte {
var level uint32
@@ -412,6 +422,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf)
}
+ if cmsgs.IP.HasIPv6PacketInfo {
+ buf = PackIPv6PacketInfo(t, &cmsgs.IP.IPv6PacketInfo, buf)
+ }
+
if cmsgs.IP.OriginalDstAddress != nil {
buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf)
}
@@ -453,6 +467,10 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo)
}
+ if cmsgs.IP.HasIPv6PacketInfo {
+ space += cmsgSpace(t, linux.SizeOfControlMessageIPv6PacketInfo)
+ }
+
if cmsgs.IP.OriginalDstAddress != nil {
space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes())
}
@@ -526,7 +544,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
var ts linux.Timeval
ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval])
- cmsgs.IP.Timestamp = ts.ToNsecCapped()
+ cmsgs.IP.Timestamp = ts.ToTime()
cmsgs.IP.HasTimestamp = true
i += bits.AlignUp(length, width)
diff --git a/pkg/sentry/socket/control/control_test.go b/pkg/sentry/socket/control/control_test.go
index 7e28a0cef..1b04e1bbc 100644
--- a/pkg/sentry/socket/control/control_test.go
+++ b/pkg/sentry/socket/control/control_test.go
@@ -50,7 +50,7 @@ func TestParse(t *testing.T) {
want := socket.ControlMessages{
IP: socket.IPControlMessages{
HasTimestamp: true,
- Timestamp: ts.ToNsecCapped(),
+ Timestamp: ts.ToTime(),
},
}
if diff := cmp.Diff(want, cmsg); diff != "" {
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 1c1e501ba..6e2318f75 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -111,7 +111,7 @@ func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
return readv(s.fd, safemem.IovecsFromBlockSeq(dsts))
}))
- return int64(n), err
+ return n, err
}
// Write implements fs.FileOperations.Write.
@@ -134,7 +134,7 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
}
return writev(s.fd, safemem.IovecsFromBlockSeq(srcs))
}))
- return int64(n), err
+ return n, err
}
// Socket implements socket.Provider.Socket.
@@ -180,7 +180,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto
}
// Pair implements socket.Provider.Pair.
-func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+func (p *socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
// Not supported by AF_INET/AF_INET6.
return nil, nil, nil
}
@@ -207,7 +207,7 @@ type socketOpsCommon struct {
// Release implements fs.FileOperations.Release.
func (s *socketOpsCommon) Release(context.Context) {
fdnotifier.RemoveFD(int32(s.fd))
- unix.Close(s.fd)
+ _ = unix.Close(s.fd)
}
// Readiness implements waiter.Waitable.Readiness.
@@ -218,13 +218,13 @@ func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
// EventRegister implements waiter.Waitable.EventRegister.
func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
s.queue.EventRegister(e, mask)
- fdnotifier.UpdateFD(int32(s.fd))
+ _ = fdnotifier.UpdateFD(int32(s.fd))
}
// EventUnregister implements waiter.Waitable.EventUnregister.
func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
s.queue.EventUnregister(e)
- fdnotifier.UpdateFD(int32(s.fd))
+ _ = fdnotifier.UpdateFD(int32(s.fd))
}
// Connect implements socket.Socket.Connect.
@@ -316,7 +316,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int,
if kernel.VFS2Enabled {
f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&unix.SOCK_NONBLOCK))
if err != nil {
- unix.Close(fd)
+ _ = unix.Close(fd)
return 0, nil, 0, err
}
defer f.DecRef(t)
@@ -328,7 +328,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int,
} else {
f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&unix.SOCK_NONBLOCK != 0)
if err != nil {
- unix.Close(fd)
+ _ = unix.Close(fd)
return 0, nil, 0, err
}
defer f.DecRef(t)
@@ -343,7 +343,7 @@ func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int,
}
// Bind implements socket.Socket.Bind.
-func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) > sizeofSockaddr {
sockaddr = sockaddr[:sizeofSockaddr]
}
@@ -356,12 +356,12 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
// Listen implements socket.Socket.Listen.
-func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error {
return syserr.FromError(unix.Listen(s.fd, backlog))
}
// Shutdown implements socket.Socket.Shutdown.
-func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error {
switch how {
case unix.SHUT_RD, unix.SHUT_WR, unix.SHUT_RDWR:
return syserr.FromError(unix.Shutdown(s.fd, how))
@@ -371,7 +371,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, _ hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
if outLen < 0 {
return nil, syserr.ErrInvalidArgument
}
@@ -401,7 +401,7 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
case linux.TCP_NODELAY:
optlen = sizeofInt32
case linux.TCP_INFO:
- optlen = int(linux.SizeOfTCPInfo)
+ optlen = linux.SizeOfTCPInfo
}
}
@@ -579,7 +579,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
controlMessages.IP.HasTimestamp = true
ts := linux.Timeval{}
ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval])
- controlMessages.IP.Timestamp = ts.ToNsecCapped()
+ controlMessages.IP.Timestamp = ts.ToTime()
}
case linux.SOL_IP:
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index e3eade180..8d9e73243 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -58,8 +58,8 @@ var nameToID = map[string]stack.TableID{
// DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for
// compatibility with netfilter extensions.
-func DefaultLinuxTables(seed uint32) *stack.IPTables {
- tables := stack.DefaultTables(seed)
+func DefaultLinuxTables(seed uint32, clock tcpip.Clock) *stack.IPTables {
+ tables := stack.DefaultTables(seed, clock)
tables.VisitTargets(func(oldTarget stack.Target) stack.Target {
switch val := oldTarget.(type) {
case *stack.AcceptTarget:
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index ea56f39c1..b9c15daab 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -647,7 +647,7 @@ func (jt *JumpTarget) id() targetID {
}
// Action implements stack.Target.Action.
-func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
+func (jt *JumpTarget) Action(*stack.PacketBuffer, stack.Hook, *stack.Route, stack.AddressableEndpoint) (stack.RuleVerdict, int) {
return stack.RuleJump, jt.RuleNum
}
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index bf5ec4558..075f61cda 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"device.go",
"netstack.go",
+ "netstack_state.go",
"netstack_vfs2.go",
"provider.go",
"provider_vfs2.go",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index f79bda922..030c6c8e4 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -274,6 +274,7 @@ var Metrics = tcpip.Stats{
ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."),
FailedPortReservations: mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of time TCP failed to reserve a port."),
SegmentsAckedWithDSACK: mustCreateMetric("/netstack/tcp/segments_acked_with_dsack", "Number of segments for which DSACK was received."),
+ SpuriousRecovery: mustCreateMetric("/netstack/tcp/spurious_recovery", "Number of times the connection entered loss recovery spuriously."),
},
UDP: tcpip.UDPStats{
PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."),
@@ -378,9 +379,9 @@ type socketOpsCommon struct {
// timestampValid indicates whether timestamp for SIOCGSTAMP has been
// set. It is protected by readMu.
timestampValid bool
- // timestampNS holds the timestamp to use with SIOCTSTAMP. It is only
+ // timestamp holds the timestamp to use with SIOCTSTAMP. It is only
// valid when timestampValid is true. It is protected by readMu.
- timestampNS int64
+ timestamp time.Time `state:".(int64)"`
// TODO(b/153685824): Move this to SocketOptions.
// sockOptInq corresponds to TCP_INQ.
@@ -410,15 +411,6 @@ var sockAddrInetSize = (*linux.SockAddrInet)(nil).SizeBytes()
var sockAddrInet6Size = (*linux.SockAddrInet6)(nil).SizeBytes()
var sockAddrLinkSize = (*linux.SockAddrLink)(nil).SizeBytes()
-// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
-// netstack representation taking any addresses into account.
-func bytesToIPAddress(addr []byte) tcpip.Address {
- if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) {
- return ""
- }
- return tcpip.Address(addr)
-}
-
// minSockAddrLen returns the minimum length in bytes of a socket address for
// the socket's family.
func (s *socketOpsCommon) minSockAddrLen() int {
@@ -468,7 +460,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) {
t := kernel.TaskFromContext(ctx)
start := t.Kernel().MonotonicClock().Now()
deadline := start.Add(v.Timeout)
- t.BlockWithDeadline(ch, true, deadline)
+ _ = t.BlockWithDeadline(ch, true, deadline)
}
}
@@ -488,7 +480,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
// WriteTo implements fs.FileOperations.WriteTo.
-func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
+func (s *SocketOperations) WriteTo(_ context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
s.readMu.Lock()
defer s.readMu.Unlock()
@@ -543,7 +535,7 @@ func (l *limitedPayloader) Len() int {
}
// ReadFrom implements fs.FileOperations.ReadFrom.
-func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+func (s *SocketOperations) ReadFrom(_ context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
f := limitedPayloader{
inner: io.LimitedReader{
R: r,
@@ -654,7 +646,7 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) < 2 {
return syserr.ErrInvalidArgument
}
@@ -672,13 +664,10 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
a.UnmarshalBytes(sockaddr[:sockAddrLinkSize])
- if a.Protocol != uint16(s.protocol) {
- return syserr.ErrInvalidArgument
- }
-
addr = tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ Port: socket.Ntohs(a.Protocol),
}
} else {
if s.minSockAddrLen() > len(sockaddr) {
@@ -717,7 +706,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
// Listen implements the linux syscall listen(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(_ *kernel.Task, backlog int) *syserr.Error {
return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog))
}
@@ -808,7 +797,7 @@ func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) {
// Shutdown implements the linux syscall shutdown(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(_ *kernel.Task, how int) *syserr.Error {
f, err := ConvertShutdown(how)
if err != nil {
return err
@@ -889,7 +878,7 @@ func boolToInt32(v bool) int32 {
}
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, _ linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -1374,6 +1363,14 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
return &v, nil
+ case linux.IPV6_RECVPKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetIPv6ReceivePacketInfo()))
+ return &v, nil
+
case linux.IP6T_ORIGINAL_DST:
if outLen < sockAddrInet6Size {
return nil, syserr.ErrInvalidArgument
@@ -1397,11 +1394,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, true)
+ info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, true)
if err != nil {
return nil, err
}
@@ -1417,11 +1414,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- entries, err := netfilter.GetEntries6(t, stack.(*Stack).Stack, outPtr, outLen)
+ entries, err := netfilter.GetEntries6(t, stk.(*Stack).Stack, outPtr, outLen)
if err != nil {
return nil, err
}
@@ -1437,8 +1434,8 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber)
@@ -1454,7 +1451,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr hostarch.Addr, outLen int, _ int) (marshal.Marshallable, *syserr.Error) {
if _, ok := ep.(tcpip.Endpoint); !ok {
log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
return nil, syserr.ErrUnknownProtocolOption
@@ -1594,11 +1591,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, false)
+ info, err := netfilter.GetInfo(t, stk.(*Stack).Stack, outPtr, false)
if err != nil {
return nil, err
}
@@ -1614,11 +1611,11 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
- entries, err := netfilter.GetEntries4(t, stack.(*Stack).Stack, outPtr, outLen)
+ entries, err := netfilter.GetEntries4(t, stk.(*Stack).Stack, outPtr, outLen)
if err != nil {
return nil, err
}
@@ -1634,8 +1631,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return nil, syserr.ErrNoDevice
}
ret, err := netfilter.TargetRevision(t, outPtr, header.IPv4ProtocolNumber)
@@ -2130,6 +2127,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
return nil
+ case linux.IPV6_RECVPKTINFO:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(hostarch.ByteOrder.Uint32(optVal))
+
+ ep.SocketOptions().SetIPv6ReceivePacketInfo(v != 0)
+ return nil
+
case linux.IPV6_TCLASS:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2172,12 +2178,12 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return syserr.ErrNoDevice
}
// Stack must be a netstack stack.
- return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, true)
+ return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, true)
case linux.IP6T_SO_SET_ADD_COUNTERS:
log.Infof("IP6T_SO_SET_ADD_COUNTERS is not supported")
@@ -2415,12 +2421,12 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return syserr.ErrProtocolNotAvailable
}
- stack := inet.StackFromContext(t)
- if stack == nil {
+ stk := inet.StackFromContext(t)
+ if stk == nil {
return syserr.ErrNoDevice
}
// Stack must be a netstack stack.
- return netfilter.SetEntries(t, stack.(*Stack).Stack, optVal, false)
+ return netfilter.SetEntries(t, stk.(*Stack).Stack, optVal, false)
case linux.IPT_SO_SET_ADD_COUNTERS:
log.Infof("IPT_SO_SET_ADD_COUNTERS is not supported")
@@ -2519,7 +2525,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
linux.IPV6_RECVPATHMTU,
- linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
linux.IPV6_RTHDR,
linux.IPV6_RTHDRDSTOPTS,
@@ -2588,7 +2593,7 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetSockName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -2600,7 +2605,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetPeerName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -2745,6 +2750,8 @@ func (s *socketOpsCommon) controlMessages(cm tcpip.ControlMessages) socket.Contr
TClass: readCM.TClass,
HasIPPacketInfo: readCM.HasIPPacketInfo,
PacketInfo: readCM.PacketInfo,
+ HasIPv6PacketInfo: readCM.HasIPv6PacketInfo,
+ IPv6PacketInfo: readCM.IPv6PacketInfo,
OriginalDstAddress: readCM.OriginalDstAddress,
SockErr: readCM.SockErr,
},
@@ -2759,7 +2766,7 @@ func (s *socketOpsCommon) updateTimestamp(cm tcpip.ControlMessages) {
// Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
if !s.sockOptTimestamp {
s.timestampValid = true
- s.timestampNS = cm.Timestamp
+ s.timestamp = cm.Timestamp
}
}
@@ -2818,7 +2825,7 @@ func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int,
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
-func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, _ uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
if flags&linux.MSG_ERRQUEUE != 0 {
return s.recvErr(t, dst)
}
@@ -2983,7 +2990,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
return 0, linuxerr.ENOENT
}
- tv := linux.NsecToTimeval(s.timestampNS)
+ tv := linux.NsecToTimeval(s.timestamp.UnixNano())
_, err := tv.CopyOut(t, args[2].Pointer())
return 0, err
@@ -3090,7 +3097,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
}
// interfaceIoctl implements interface requests.
-func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error {
+func interfaceIoctl(ctx context.Context, _ usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error {
var (
iface inet.Interface
index int32
@@ -3098,8 +3105,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
)
// Find the relevant device.
- stack := inet.StackFromContext(ctx)
- if stack == nil {
+ stk := inet.StackFromContext(ctx)
+ if stk == nil {
return syserr.ErrNoDevice
}
@@ -3109,7 +3116,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
// Gets the name of the interface given the interface index
// stored in ifr_ifindex.
index = int32(hostarch.ByteOrder.Uint32(ifr.Data[:4]))
- if iface, ok := stack.Interfaces()[index]; ok {
+ if iface, ok := stk.Interfaces()[index]; ok {
ifr.SetName(iface.Name)
return nil
}
@@ -3117,7 +3124,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
// Find the relevant device.
- for index, iface = range stack.Interfaces() {
+ for index, iface = range stk.Interfaces() {
if iface.Name == ifr.Name() {
found = true
break
@@ -3150,7 +3157,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
case linux.SIOCGIFFLAGS:
- f, err := interfaceStatusFlags(stack, iface.Name)
+ f, err := interfaceStatusFlags(stk, iface.Name)
if err != nil {
return err
}
@@ -3160,7 +3167,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
case linux.SIOCGIFADDR:
// Copy the IPv4 address out.
- for _, addr := range stack.InterfaceAddrs()[index] {
+ for _, addr := range stk.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
if addr.Family != linux.AF_INET {
continue
@@ -3196,7 +3203,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
case linux.SIOCGIFNETMASK:
// Gets the network mask of a device.
- for _, addr := range stack.InterfaceAddrs()[index] {
+ for _, addr := range stk.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
if addr.Family != linux.AF_INET {
continue
@@ -3228,24 +3235,24 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
// ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl.
-func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error {
+func ifconfIoctl(ctx context.Context, t *kernel.Task, _ usermem.IO, ifc *linux.IFConf) error {
// If Ptr is NULL, return the necessary buffer size via Len.
// Otherwise, write up to Len bytes starting at Ptr containing ifreq
// structs.
- stack := inet.StackFromContext(ctx)
- if stack == nil {
+ stk := inet.StackFromContext(ctx)
+ if stk == nil {
return syserr.ErrNoDevice.ToError()
}
if ifc.Ptr == 0 {
- ifc.Len = int32(len(stack.Interfaces())) * int32(linux.SizeOfIFReq)
+ ifc.Len = int32(len(stk.Interfaces())) * int32(linux.SizeOfIFReq)
return nil
}
max := ifc.Len
ifc.Len = 0
- for key, ifaceAddrs := range stack.InterfaceAddrs() {
- iface := stack.Interfaces()[key]
+ for key, ifaceAddrs := range stk.InterfaceAddrs() {
+ iface := stk.Interfaces()[key]
for _, ifaceAddr := range ifaceAddrs {
// Don't write past the end of the buffer.
if ifc.Len+int32(linux.SizeOfIFReq) > max {
diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/sentry/socket/netstack/netstack_state.go
index 529e02a07..591e00d42 100644
--- a/pkg/tcpip/stack/iptables_state.go
+++ b/pkg/sentry/socket/netstack/netstack_state.go
@@ -1,4 +1,4 @@
-// Copyright 2020 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,29 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package stack
+package netstack
import (
"time"
)
-// +stateify savable
-type unixTime struct {
- second int64
- nano int64
+func (s *socketOpsCommon) saveTimestamp() int64 {
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ return s.timestamp.UnixNano()
}
-// saveLastUsed is invoked by stateify.
-func (cn *conn) saveLastUsed() unixTime {
- return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()}
-}
-
-// loadLastUsed is invoked by stateify.
-func (cn *conn) loadLastUsed(unix unixTime) {
- cn.lastUsed = time.Unix(unix.second, unix.nano)
-}
-
-// beforeSave is invoked by stateify.
-func (ct *ConnTrack) beforeSave() {
- ct.mu.Lock()
+func (s *socketOpsCommon) loadTimestamp(nsec int64) {
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.timestamp = time.Unix(0, nsec)
}
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 208ab9909..ea199f223 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -155,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
// Attach address to interface.
nicID := tcpip.NICID(idx)
- if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ if err := s.Stack.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil {
return syserr.TranslateNetstackError(err).ToError()
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 841d5bd55..d4b80a39d 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -21,6 +21,7 @@ import (
"bytes"
"fmt"
"sync/atomic"
+ "time"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -51,8 +52,19 @@ type ControlMessages struct {
func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo {
var p linux.ControlMessageIPPacketInfo
p.NIC = int32(packetInfo.NIC)
- copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
- copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+ copy(p.LocalAddr[:], packetInfo.LocalAddr)
+ copy(p.DestinationAddr[:], packetInfo.DestinationAddr)
+ return p
+}
+
+// ipv6PacketInfoToLinux converts IPv6PacketInfo from tcpip format to Linux
+// format.
+func ipv6PacketInfoToLinux(packetInfo tcpip.IPv6PacketInfo) linux.ControlMessageIPv6PacketInfo {
+ var p linux.ControlMessageIPv6PacketInfo
+ if n := copy(p.Addr[:], packetInfo.Addr); n != len(p.Addr) {
+ panic(fmt.Sprintf("got copy(%x, %x) = %d, want = %d", p.Addr, packetInfo.Addr, n, len(p.Addr)))
+ }
+ p.NIC = uint32(packetInfo.NIC)
return p
}
@@ -114,7 +126,7 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa
if cmgs.HasOriginalDstAddress {
orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress)
}
- return IPControlMessages{
+ cm := IPControlMessages{
HasTimestamp: cmgs.HasTimestamp,
Timestamp: cmgs.Timestamp,
HasInq: cmgs.HasInq,
@@ -125,9 +137,16 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa
TClass: cmgs.TClass,
HasIPPacketInfo: cmgs.HasIPPacketInfo,
PacketInfo: packetInfoToLinux(cmgs.PacketInfo),
+ HasIPv6PacketInfo: cmgs.HasIPv6PacketInfo,
OriginalDstAddress: orgDstAddr,
SockErr: sockErrCmsgToLinux(cmgs.SockErr),
}
+
+ if cm.HasIPv6PacketInfo {
+ cm.IPv6PacketInfo = ipv6PacketInfoToLinux(cmgs.IPv6PacketInfo)
+ }
+
+ return cm
}
// IPControlMessages contains socket control messages for IP sockets.
@@ -138,9 +157,9 @@ type IPControlMessages struct {
// HasTimestamp indicates whether Timestamp is valid/set.
HasTimestamp bool
- // Timestamp is the time (in ns) that the last packet used to create
- // the read data was received.
- Timestamp int64
+ // Timestamp is the time that the last packet used to create the read data
+ // was received.
+ Timestamp time.Time `state:".(int64)"`
// HasInq indicates whether Inq is valid/set.
HasInq bool
@@ -166,6 +185,12 @@ type IPControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo linux.ControlMessageIPPacketInfo
+ // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set.
+ HasIPv6PacketInfo bool
+
+	// IPv6PacketInfo holds interface and address data on an incoming packet.
+ IPv6PacketInfo linux.ControlMessageIPv6PacketInfo
+
// OriginalDestinationAddress holds the original destination address
// and port of the incoming packet.
OriginalDstAddress linux.SockAddr
diff --git a/pkg/sentry/socket/socket_state.go b/pkg/sentry/socket/socket_state.go
new file mode 100644
index 000000000..32e12b238
--- /dev/null
+++ b/pkg/sentry/socket/socket_state.go
@@ -0,0 +1,27 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package socket
+
+import (
+ "time"
+)
+
+func (i *IPControlMessages) saveTimestamp() int64 {
+ return i.Timestamp.UnixNano()
+}
+
+func (i *IPControlMessages) loadTimestamp(nsec int64) {
+ i.Timestamp = time.Unix(0, nsec)
+}
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index a9cedcf5f..188ad3bd9 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -59,12 +59,14 @@ func (q *queue) Close() {
// q.WriterQueue.Notify(waiter.WritableEvents)
func (q *queue) Reset(ctx context.Context) {
q.mu.Lock()
- for cur := q.dataList.Front(); cur != nil; cur = cur.Next() {
- cur.Release(ctx)
- }
+ dataList := q.dataList
q.dataList.Reset()
q.used = 0
q.mu.Unlock()
+
+ for cur := dataList.Front(); cur != nil; cur = cur.Next() {
+ cur.Release(ctx)
+ }
}
// DecRef implements RefCounter.DecRef.
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 757ff2a40..4d3f4d556 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -610,9 +610,9 @@ func (i *SyscallInfo) printExit(t *kernel.Task, elapsed time.Duration, output []
if err == nil {
// Fill in the output after successful execution.
i.post(t, args, retval, output, LogMaximumSize)
- rval = fmt.Sprintf("%#x (%v)", retval, elapsed)
+ rval = fmt.Sprintf("%d (%#x) (%v)", retval, retval, elapsed)
} else {
- rval = fmt.Sprintf("%#x errno=%d (%s) (%v)", retval, errno, err, elapsed)
+ rval = fmt.Sprintf("%d (%#x) errno=%d (%s) (%v)", retval, retval, errno, err, elapsed)
}
switch len(output) {
diff --git a/pkg/sentry/time/sampler_arm64.go b/pkg/sentry/time/sampler_arm64.go
index 3560e66ae..9b8c9a480 100644
--- a/pkg/sentry/time/sampler_arm64.go
+++ b/pkg/sentry/time/sampler_arm64.go
@@ -30,9 +30,9 @@ func getDefaultArchOverheadCycles() TSCValue {
// frqRatio. defaultOverheadCycles of ARM equals to that on
// x86 devided by frqRatio
cntfrq := getCNTFRQ()
- frqRatio := 1000000000 / cntfrq
+ frqRatio := 1000000000 / float64(cntfrq)
overheadCycles := (1 * 1000) / frqRatio
- return overheadCycles
+ return TSCValue(overheadCycles)
}
// defaultOverheadTSC is the default estimated syscall overhead in TSC cycles.
diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go
index e7073ec87..d9df890c4 100644
--- a/pkg/sentry/usage/memory.go
+++ b/pkg/sentry/usage/memory.go
@@ -252,9 +252,9 @@ func (m *MemoryLocked) Copy() (MemoryStats, uint64) {
return ms, m.totalLocked()
}
-// These options control how much total memory the is reported to the application.
-// They may only be set before the application starts executing, and must not
-// be modified.
+// These options control how much total memory is reported to the
+// application. They may only be set before the application starts executing,
+// and must not be modified.
var (
// MinimumTotalMemoryBytes is the minimum reported total system memory.
MinimumTotalMemoryBytes uint64 = 2 << 30 // 2 GB
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 04bc4d10c..fefd0fc9c 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -135,12 +135,16 @@ func (ep *EpollInstance) Readiness(mask waiter.EventMask) waiter.EventMask {
return 0
}
ep.mu.Lock()
- for epi := ep.ready.Front(); epi != nil; epi = epi.Next() {
+ var next *epollInterest
+ for epi := ep.ready.Front(); epi != nil; epi = next {
+ next = epi.Next()
wmask := waiter.EventMaskFromLinux(epi.mask)
if epi.key.file.Readiness(wmask)&wmask != 0 {
ep.mu.Unlock()
return waiter.ReadableEvents
}
+ ep.ready.Remove(epi)
+ epi.ready = false
}
ep.mu.Unlock()
return 0
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 5dab069ed..452f5f1f9 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -17,6 +17,7 @@ package vfs
import (
"bytes"
"io"
+ "math"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -399,6 +400,9 @@ func (fd *DynamicBytesFileDescriptionImpl) Write(ctx context.Context, src userme
// GenericConfigureMMap may be used by most implementations of
// FileDescriptionImpl.ConfigureMMap.
func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.MMapOpts) error {
+ if opts.Offset+opts.Length > math.MaxInt64 {
+ return linuxerr.EOVERFLOW
+ }
opts.Mappable = m
opts.MappingIdentity = fd
fd.IncRef()
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index 7fd7f000d..40aff2927 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -223,6 +223,12 @@ func (rp *ResolvingPath) Final() bool {
return rp.curPart == 0 && !rp.pit.NextOk()
}
+// Pit returns a copy of rp's current path iterator. Modifying the iterator
+// does not change rp.
+func (rp *ResolvingPath) Pit() fspath.Iterator {
+ return rp.pit
+}
+
// Component returns the current path component in the stream represented by
// rp.
//
diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go
index 8998a82dd..8a6ced365 100644
--- a/pkg/sentry/vfs/save_restore.go
+++ b/pkg/sentry/vfs/save_restore.go
@@ -15,7 +15,6 @@
package vfs
import (
- "fmt"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -24,6 +23,18 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// ErrCorruption indicates a failed restore due to corruption in the external
+// file system state.
+type ErrCorruption struct {
+ // Err is the wrapped error.
+ Err error
+}
+
+// Error returns a sensible description of the restore error.
+func (e ErrCorruption) Error() string {
+ return "restore failed due to external file system state in corruption: " + e.Err.Error()
+}
+
// FilesystemImplSaveRestoreExtension is an optional extension to
// FilesystemImpl.
type FilesystemImplSaveRestoreExtension interface {
@@ -37,38 +48,30 @@ type FilesystemImplSaveRestoreExtension interface {
// PrepareSave prepares all filesystems for serialization.
func (vfs *VirtualFilesystem) PrepareSave(ctx context.Context) error {
- failures := 0
for fs := range vfs.getFilesystems() {
if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok {
if err := ext.PrepareSave(ctx); err != nil {
- ctx.Warningf("%T.PrepareSave failed: %v", fs.impl, err)
- failures++
+ fs.DecRef(ctx)
+ return err
}
}
fs.DecRef(ctx)
}
- if failures != 0 {
- return fmt.Errorf("%d filesystems failed to prepare for serialization", failures)
- }
return nil
}
// CompleteRestore completes restoration from checkpoint for all filesystems
// after deserialization.
func (vfs *VirtualFilesystem) CompleteRestore(ctx context.Context, opts *CompleteRestoreOptions) error {
- failures := 0
for fs := range vfs.getFilesystems() {
if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok {
if err := ext.CompleteRestore(ctx, *opts); err != nil {
- ctx.Warningf("%T.CompleteRestore failed: %v", fs.impl, err)
- failures++
+ fs.DecRef(ctx)
+ return err
}
}
fs.DecRef(ctx)
}
- if failures != 0 {
- return fmt.Errorf("%d filesystems failed to complete restore after deserialization", failures)
- }
return nil
}
diff --git a/pkg/shim/service.go b/pkg/shim/service.go
index 24e3b7a82..0980d964e 100644
--- a/pkg/shim/service.go
+++ b/pkg/shim/service.go
@@ -77,6 +77,8 @@ const (
// shimAddressPath is the relative path to a file that contains the address
// to the shim UDS. See service.shimAddress.
shimAddressPath = "address"
+
+ cgroupParentAnnotation = "dev.gvisor.spec.cgroup-parent"
)
// New returns a new shim service that can be used via GRPC.
@@ -952,7 +954,7 @@ func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.C
if err != nil {
return nil, fmt.Errorf("update volume annotations: %w", err)
}
- updated = updateCgroup(spec) || updated
+ updated = setPodCgroup(spec) || updated
if updated {
if err := utils.WriteSpec(r.Bundle, spec); err != nil {
@@ -980,12 +982,13 @@ func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.C
return p, nil
}
-// updateCgroup updates cgroup path for the sandbox to make the sandbox join the
-// pod cgroup and not the pause container cgroup. Returns true if the spec was
-// modified. Ex.:
-// /kubepods/burstable/pod123/abc => kubepods/burstable/pod123
+// setPodCgroup searches for the pod cgroup path inside the container's cgroup
+// path. If found, it's set as an annotation in the spec. This is done so that
+// the sandbox joins the pod cgroup. Otherwise, the sandbox would join the pause
+// container cgroup. Returns true if the spec was modified. Ex.:
+// /kubepods/burstable/pod123/container123 => kubepods/burstable/pod123
//
-func updateCgroup(spec *specs.Spec) bool {
+func setPodCgroup(spec *specs.Spec) bool {
if !utils.IsSandbox(spec) {
return false
}
@@ -1009,7 +1012,10 @@ func updateCgroup(spec *specs.Spec) bool {
if spec.Linux.CgroupsPath == path {
return false
}
- spec.Linux.CgroupsPath = path
+ if spec.Annotations == nil {
+ spec.Annotations = make(map[string]string)
+ }
+ spec.Annotations[cgroupParentAnnotation] = path
return true
}
}
diff --git a/pkg/shim/service_test.go b/pkg/shim/service_test.go
index 2d9f07e02..4b4410a58 100644
--- a/pkg/shim/service_test.go
+++ b/pkg/shim/service_test.go
@@ -40,12 +40,12 @@ func TestCgroupPath(t *testing.T) {
{
name: "no-container",
path: "foo/pod123",
- want: "foo/pod123",
+ want: "",
},
{
name: "no-container-absolute",
path: "/foo/pod123",
- want: "/foo/pod123",
+ want: "",
},
{
name: "double-pod",
@@ -70,7 +70,7 @@ func TestCgroupPath(t *testing.T) {
{
name: "no-pod",
path: "/foo/nopod123/container",
- want: "/foo/nopod123/container",
+ want: "",
},
} {
t.Run(tc.name, func(t *testing.T) {
@@ -79,12 +79,12 @@ func TestCgroupPath(t *testing.T) {
CgroupsPath: tc.path,
},
}
- updated := updateCgroup(&spec)
- if spec.Linux.CgroupsPath != tc.want {
- t.Errorf("updateCgroup(%q), want: %q, got: %q", tc.path, tc.want, spec.Linux.CgroupsPath)
+ updated := setPodCgroup(&spec)
+ if got := spec.Annotations[cgroupParentAnnotation]; got != tc.want {
+ t.Errorf("setPodCgroup(%q), want: %q, got: %q", tc.path, tc.want, got)
}
- if shouldUpdate := tc.path != tc.want; shouldUpdate != updated {
- t.Errorf("updateCgroup(%q)=%v, want: %v", tc.path, updated, shouldUpdate)
+ if shouldUpdate := len(tc.want) > 0; shouldUpdate != updated {
+ t.Errorf("setPodCgroup(%q)=%v, want: %v", tc.path, updated, shouldUpdate)
}
})
}
@@ -113,8 +113,8 @@ func TestCgroupNoUpdate(t *testing.T) {
},
} {
t.Run(tc.name, func(t *testing.T) {
- if updated := updateCgroup(tc.spec); updated {
- t.Errorf("updateCgroup(%+v), got: %v, want: false", tc.spec.Linux, updated)
+ if updated := setPodCgroup(tc.spec); updated {
+ t.Errorf("setPodCgroup(%+v), got: %v, want: false", tc.spec.Linux, updated)
}
})
}
diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sighandling/BUILD
index 1790d57c9..72f10f982 100644
--- a/pkg/sentry/sighandling/BUILD
+++ b/pkg/sighandling/BUILD
@@ -8,7 +8,7 @@ go_library(
"sighandling.go",
"sighandling_unsafe.go",
],
- visibility = ["//pkg/sentry:internal"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sighandling/sighandling.go
index bdaf8af29..bdaf8af29 100644
--- a/pkg/sentry/sighandling/sighandling.go
+++ b/pkg/sighandling/sighandling.go
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sighandling/sighandling_unsafe.go
index 3fe5c6770..7deeda042 100644
--- a/pkg/sentry/sighandling/sighandling_unsafe.go
+++ b/pkg/sighandling/sighandling_unsafe.go
@@ -15,6 +15,7 @@
package sighandling
import (
+ "fmt"
"unsafe"
"golang.org/x/sys/unix"
@@ -37,3 +38,36 @@ func IgnoreChildStop() error {
return nil
}
+
+// ReplaceSignalHandler replaces the existing signal handler for the provided
+// signal with the function pointer at `handler`. This bypasses the Go runtime
+// signal handlers, and should only be used for low-level signal handlers where
+// use of signal.Notify is not appropriate.
+//
+// It stores the value of the previously set handler in previous.
+func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error {
+ var sa linux.SigAction
+ const maskLen = 8
+
+ // Get the existing signal handler information, and save the current
+ // handler. Once we replace it, we will use this pointer to fall back to
+ // it when we receive other signals.
+ if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), 0, uintptr(unsafe.Pointer(&sa)), maskLen, 0, 0); e != 0 {
+ return e
+ }
+
+ // Fail if there isn't a previous handler.
+ if sa.Handler == 0 {
+ return fmt.Errorf("previous handler for signal %x isn't set", sig)
+ }
+
+ *previous = uintptr(sa.Handler)
+
+ // Install our own handler.
+ sa.Handler = uint64(handler)
+ if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, maskLen, 0, 0); e != 0 {
+ return e
+ }
+
+ return nil
+}
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index 73791b456..517f16329 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -26,6 +26,7 @@ go_library(
"rwmutex_unsafe.go",
"seqcount.go",
"sync.go",
+ "wait.go",
],
marshal = False,
stateify = False,
diff --git a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go
index 82b6df18c..7b9c2a4db 100644
--- a/pkg/sync/atomicptr/generic_atomicptr_unsafe.go
+++ b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go
@@ -37,6 +37,8 @@ func (p *AtomicPtr) loadPtr(v *Value) {
// Load returns the value set by the most recent Store. It returns nil if there
// has been no previous call to Store.
+//
+//go:nosplit
func (p *AtomicPtr) Load() *Value {
return (*Value)(atomic.LoadPointer(&p.ptr))
}
diff --git a/pkg/sync/wait.go b/pkg/sync/wait.go
new file mode 100644
index 000000000..f8e7742a5
--- /dev/null
+++ b/pkg/sync/wait.go
@@ -0,0 +1,58 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sync
+
+// WaitGroupErr is similar to WaitGroup but allows goroutines to report errors.
+// Only the first error is retained and reported back.
+//
+// Example usage:
+// wg := WaitGroupErr{}
+// wg.Add(1)
+// go func() {
+// defer wg.Done()
+// if err := ...; err != nil {
+// wg.ReportError(err)
+// return
+// }
+// }()
+// return wg.Error()
+//
+type WaitGroupErr struct {
+ WaitGroup
+
+ // mu protects firstErr.
+ mu Mutex
+
+	// firstErr holds the first error reported. nil if no error occurred.
+ firstErr error
+}
+
+// ReportError reports an error. Note it does not call Done().
+func (w *WaitGroupErr) ReportError(err error) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if w.firstErr == nil {
+ w.firstErr = err
+ }
+}
+
+// Error waits for the counter to reach 0 and returns the first reported error
+// if any.
+func (w *WaitGroupErr) Error() error {
+ w.Wait()
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return w.firstErr
+}
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index dbe4506cc..b98de54c5 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -25,6 +25,7 @@ go_library(
"stdclock.go",
"stdclock_state.go",
"tcpip.go",
+ "tcpip_state.go",
"timer.go",
],
visibility = ["//visibility:public"],
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 010e2e833..1f2bcaf65 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -19,6 +19,7 @@ import (
"bytes"
"context"
"errors"
+ "fmt"
"io"
"net"
"time"
@@ -471,9 +472,9 @@ func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtoc
return DialContextTCP(context.Background(), s, addr, network)
}
-// DialContextTCP creates a new TCPConn connected to the specified address
-// with the option of adding cancellation and timeouts.
-func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
+// DialTCPWithBind creates a new TCPConn connected to the specified
+// remoteAddress with its local address bound to localAddr.
+func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
// Create TCP endpoint, then connect.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
@@ -494,7 +495,14 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
default:
}
- err = ep.Connect(addr)
+ // Bind before connect if requested.
+ if localAddr != (tcpip.FullAddress{}) {
+ if err = ep.Bind(localAddr); err != nil {
+ return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err)
+ }
+ }
+
+ err = ep.Connect(remoteAddr)
if _, ok := err.(*tcpip.ErrConnectStarted); ok {
select {
case <-ctx.Done():
@@ -510,7 +518,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
return nil, &net.OpError{
Op: "connect",
Net: "tcp",
- Addr: fullToTCPAddr(addr),
+ Addr: fullToTCPAddr(remoteAddr),
Err: errors.New(err.String()),
}
}
@@ -518,6 +526,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
return NewTCPConn(&wq, ep), nil
}
+// DialContextTCP creates a new TCPConn connected to the specified address
+// with the option of adding cancellation and timeouts.
+func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
+ return DialTCPWithBind(ctx, s, tcpip.FullAddress{} /* localAddr */, addr /* remoteAddr */, network)
+}
+
// A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements
// net.Conn and net.PacketConn.
type UDPConn struct {
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 48b24692b..dcc9fff17 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -137,7 +137,13 @@ func TestCloseReader(t *testing.T) {
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
@@ -190,7 +196,13 @@ func TestCloseReaderWithForwarder(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
done := make(chan struct{})
@@ -244,7 +256,13 @@ func TestCloseRead(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
@@ -288,7 +306,13 @@ func TestCloseWrite(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
@@ -349,10 +373,22 @@ func TestUDPForwarder(t *testing.T) {
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
+ }
ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
+ }
done := make(chan struct{})
fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
@@ -410,7 +446,13 @@ func TestDeadlineChange(t *testing.T) {
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
@@ -465,10 +507,22 @@ func TestPacketConnTransfer(t *testing.T) {
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr1, err)
+ }
ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr2, err)
+ }
c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
if err != nil {
@@ -521,7 +575,13 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
if err != nil {
@@ -565,24 +625,30 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ip.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, nil, nil, fmt.Errorf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
- return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
+ return nil, nil, nil, fmt.Errorf("NewListener: %w", err)
}
c1, err = DialTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
l.Close()
- return nil, nil, nil, fmt.Errorf("DialTCP: %v", err)
+ return nil, nil, nil, fmt.Errorf("DialTCP: %w", err)
}
c2, err = l.Accept()
if err != nil {
l.Close()
c1.Close()
- return nil, nil, nil, fmt.Errorf("l.Accept: %v", err)
+ return nil, nil, nil, fmt.Errorf("l.Accept: %w", err)
}
stop = func() {
@@ -594,7 +660,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
if err := l.Close(); err != nil {
stop()
- return nil, nil, nil, fmt.Errorf("l.Close(): %v", err)
+ return nil, nil, nil, fmt.Errorf("l.Close(): %w", err)
}
return c1, c2, stop, nil
@@ -681,7 +747,13 @@ func TestDialContextTCPCanceled(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
@@ -703,7 +775,13 @@ func TestDialContextTCPTimeout(t *testing.T) {
}()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
- s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.Addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", NICID, protocolAddr, err)
+ }
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
time.Sleep(time.Second)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 2f34bf8dd..24c2c3e6b 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -324,6 +324,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
}
}
+// ReceiveIPv6PacketInfo creates a checker that checks the IPv6PacketInfo field
+// in ControlMessages.
+func ReceiveIPv6PacketInfo(want tcpip.IPv6PacketInfo) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasIPv6PacketInfo {
+ t.Errorf("got cm.HasIPv6PacketInfo = %t, want = true", cm.HasIPv6PacketInfo)
+ } else if diff := cmp.Diff(want, cm.IPv6PacketInfo); diff != "" {
+ t.Errorf("IPv6PacketInfo mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
// field in ControlMessages.
func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index dcc549c7b..7baaf0d17 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -208,6 +208,15 @@ var IPv4EmptySubnet = func() tcpip.Subnet {
return subnet
}()
+// IPv4LoopbackSubnet is the loopback subnet for IPv4.
+var IPv4LoopbackSubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet(tcpip.Address("\x7f\x00\x00\x00"), tcpip.AddressMask("\xff\x00\x00\x00"))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
// IPVersion returns the version of IP used in the given packet. It returns -1
// if the packet is not large enough to contain the version field.
func IPVersion(b []byte) int {
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go
index 1c913b5e1..80a9ad6be 100644
--- a/pkg/tcpip/header/parse/parse.go
+++ b/pkg/tcpip/header/parse/parse.go
@@ -110,6 +110,16 @@ traverseExtensions:
switch extHdr := extHdr.(type) {
case header.IPv6FragmentExtHdr:
+ if extHdr.IsAtomic() {
+ // This fragment extension header indicates that this packet is an
+ // atomic fragment. An atomic fragment is a fragment that contains
+ // all the data required to reassemble a full packet. As per RFC 6946,
+ // atomic fragments must not interfere with "normal" fragmented traffic
+ // so we skip processing the fragment instead of feeding it through the
+ // reassembly process below.
+ continue
+ }
+
if fragID == 0 && fragOffset == 0 && !fragMore {
fragID = extHdr.ID()
fragOffset = extHdr.FragmentOffset()
@@ -175,3 +185,61 @@ func TCP(pkt *stack.PacketBuffer) bool {
pkt.TransportProtocolNumber = header.TCPProtocolNumber
return ok
}
+
+// ICMPv4 populates the packet buffer's transport header with an ICMPv4 header,
+// if present.
+//
+// Returns true if an ICMPv4 header was successfully parsed.
+func ICMPv4(pkt *stack.PacketBuffer) bool {
+ if _, ok := pkt.TransportHeader().Consume(header.ICMPv4MinimumSize); ok {
+ pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
+ return true
+ }
+ return false
+}
+
+// ICMPv6 populates the packet buffer's transport header with an ICMPv6 header,
+// if present.
+//
+// Returns true if an ICMPv6 header was successfully parsed.
+func ICMPv6(pkt *stack.PacketBuffer) bool {
+ hdr, ok := pkt.Data().PullUp(header.ICMPv6MinimumSize)
+ if !ok {
+ return false
+ }
+
+ h := header.ICMPv6(hdr)
+ switch h.Type() {
+ case header.ICMPv6RouterSolicit,
+ header.ICMPv6RouterAdvert,
+ header.ICMPv6NeighborSolicit,
+ header.ICMPv6NeighborAdvert,
+ header.ICMPv6RedirectMsg:
+ size := pkt.Data().Size()
+ if _, ok := pkt.TransportHeader().Consume(size); !ok {
+ panic(fmt.Sprintf("expected to consume the full data of size = %d bytes into transport header", size))
+ }
+ case header.ICMPv6MulticastListenerQuery,
+ header.ICMPv6MulticastListenerReport,
+ header.ICMPv6MulticastListenerDone:
+ size := header.ICMPv6HeaderSize + header.MLDMinimumSize
+ if _, ok := pkt.TransportHeader().Consume(size); !ok {
+ return false
+ }
+ case header.ICMPv6DstUnreachable,
+ header.ICMPv6PacketTooBig,
+ header.ICMPv6TimeExceeded,
+ header.ICMPv6ParamProblem,
+ header.ICMPv6EchoRequest,
+ header.ICMPv6EchoReply:
+ fallthrough
+ default:
+ if _, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize); !ok {
+ // Checked above if the packet buffer holds at least the minimum size for
+ // an ICMPv6 packet.
+ panic(fmt.Sprintf("expected to consume %d bytes", header.ICMPv6MinimumSize))
+ }
+ }
+ pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
+ return true
+}
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 3ed0aa3fe..c67ca98ea 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -123,4 +123,6 @@ func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber
}
// WriteRawPacket implements stack.LinkEndpoint.
-func (*Endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
+func (e *Endpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ return e.WritePacket(stack.RouteInfo{}, 0, pkt)
+}
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 87a0b9a62..e53789d92 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -152,10 +152,22 @@ type PollEvent struct {
// no data is available, it will block in a poll() syscall until the file
// descriptor becomes readable.
func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
+ n, err := BlockingReadUntranslated(fd, b)
+ if err != 0 {
+ return n, TranslateErrno(err)
+ }
+ return n, nil
+}
+
+// BlockingReadUntranslated reads from a file descriptor that is set up as
+// non-blocking. If no data is available, it will block in a poll() syscall
+// until the file descriptor becomes readable. It returns the raw unix.Errno
+// value returned by the underlying syscalls.
+func BlockingReadUntranslated(fd int, b []byte) (int, unix.Errno) {
for {
n, _, e := unix.RawSyscall(unix.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
if e == 0 {
- return int(n), nil
+ return int(n), 0
}
event := PollEvent{
@@ -165,7 +177,7 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
_, e = BlockingPoll(&event, 1, nil)
if e != 0 && e != unix.EINTR {
- return 0, TranslateErrno(e)
+ return 0, e
}
}
}
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index 4215ee852..af755473c 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -5,19 +5,27 @@ package(licenses = ["notice"])
go_library(
name = "sharedmem",
srcs = [
+ "queuepair.go",
"rx.go",
+ "server_rx.go",
+ "server_tx.go",
"sharedmem.go",
+ "sharedmem_server.go",
"sharedmem_unsafe.go",
"tx.go",
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/cleanup",
+ "//pkg/eventfd",
"//pkg/log",
+ "//pkg/memutil",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/link/sharedmem/pipe",
"//pkg/tcpip/link/sharedmem/queue",
"//pkg/tcpip/stack",
"@org_golang_x_sys//unix:go_default_library",
@@ -26,9 +34,7 @@ go_library(
go_test(
name = "sharedmem_test",
- srcs = [
- "sharedmem_test.go",
- ],
+ srcs = ["sharedmem_test.go"],
library = ":sharedmem",
deps = [
"//pkg/sync",
@@ -41,3 +47,22 @@ go_test(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "sharedmem_server_test",
+ size = "small",
+ srcs = ["sharedmem_server_test.go"],
+ deps = [
+ ":sharedmem",
+ "//pkg/tcpip",
+ "//pkg/tcpip/adapters/gonet",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go
index 696e6c9e5..a78826ebc 100644
--- a/pkg/tcpip/link/sharedmem/queue/rx.go
+++ b/pkg/tcpip/link/sharedmem/queue/rx.go
@@ -119,7 +119,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
}
r.tx.Flush()
-
return true
}
@@ -131,7 +130,6 @@ func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) {
for {
outBufs := bufs
-
// Pull the next descriptor from the rx pipe.
b := r.rx.Pull()
if b == nil {
diff --git a/pkg/tcpip/link/sharedmem/queuepair.go b/pkg/tcpip/link/sharedmem/queuepair.go
new file mode 100644
index 000000000..b12647fdd
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queuepair.go
@@ -0,0 +1,199 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "fmt"
+ "io/ioutil"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
+)
+
+const (
+ // defaultQueueDataSize is the size of the shared memory data region that
+ // holds the scatter/gather buffers.
+ defaultQueueDataSize = 1 << 20 // 1MiB
+
+ // defaultQueuePipeSize is the size of the pipe that holds the packet descriptors.
+ //
+ // Assuming each packet data is approximately 1280 bytes (IPv6 Minimum MTU)
+ // then we can hold approximately 1024*1024/1280 ~ 819 packets in the data
+ // area. Which means the pipe needs to be big enough to hold 819
+ // descriptors.
+ //
+ // Each descriptor is approximately 8 (slot descriptor in pipe) +
+ // 16 (packet descriptor) + 12 (for buffer descriptor) assuming each packet is
+ // stored in exactly 1 buffer descriptor (see queue/tx.go and pipe/tx.go.)
+ //
+ // Which means we need approximately 36*819 ~ 29 KiB to store all packet
+ // descriptors. We could go with a 32 KiB pipe but to give it some slack in
+ // how the upper layer may make use of the scatter gather buffers we double
+ // this to hold enough descriptors.
+ defaultQueuePipeSize = 64 << 10 // 64KiB
+
+ // defaultSharedDataSize is the size of the sharedData region used to
+ // enable/disable notifications.
+ defaultSharedDataSize = 4 << 10 // 4KiB
+)
+
+// A QueuePair represents a pair of TX/RX queues.
+type QueuePair struct {
+ // txCfg is the QueueConfig to be used for transmit queue.
+ txCfg QueueConfig
+
+ // rxCfg is the QueueConfig to be used for receive queue.
+ rxCfg QueueConfig
+}
+
+// NewQueuePair creates a shared memory QueuePair.
+func NewQueuePair() (*QueuePair, error) {
+ txCfg, err := createQueueFDs(queueSizes{
+ dataSize: defaultQueueDataSize,
+ txPipeSize: defaultQueuePipeSize,
+ rxPipeSize: defaultQueuePipeSize,
+ sharedDataSize: defaultSharedDataSize,
+ })
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to create tx queue: %s", err)
+ }
+
+ rxCfg, err := createQueueFDs(queueSizes{
+ dataSize: defaultQueueDataSize,
+ txPipeSize: defaultQueuePipeSize,
+ rxPipeSize: defaultQueuePipeSize,
+ sharedDataSize: defaultSharedDataSize,
+ })
+
+ if err != nil {
+ closeFDs(txCfg)
+ return nil, fmt.Errorf("failed to create rx queue: %s", err)
+ }
+
+ return &QueuePair{
+ txCfg: txCfg,
+ rxCfg: rxCfg,
+ }, nil
+}
+
+// Close closes underlying tx/rx queue fds.
+func (q *QueuePair) Close() {
+ closeFDs(q.txCfg)
+ closeFDs(q.rxCfg)
+}
+
+// TXQueueConfig returns the QueueConfig for the transmit queue.
+func (q *QueuePair) TXQueueConfig() QueueConfig {
+ return q.txCfg
+}
+
+// RXQueueConfig returns the QueueConfig for the receive queue.
+func (q *QueuePair) RXQueueConfig() QueueConfig {
+ return q.rxCfg
+}
+
+type queueSizes struct {
+ dataSize int64
+ txPipeSize int64
+ rxPipeSize int64
+ sharedDataSize int64
+}
+
+func createQueueFDs(s queueSizes) (QueueConfig, error) {
+ success := false
+ var eventFD eventfd.Eventfd
+ var dataFD, txPipeFD, rxPipeFD, sharedDataFD int
+ defer func() {
+ if success {
+ return
+ }
+ closeFDs(QueueConfig{
+ EventFD: eventFD,
+ DataFD: dataFD,
+ TxPipeFD: txPipeFD,
+ RxPipeFD: rxPipeFD,
+ SharedDataFD: sharedDataFD,
+ })
+ }()
+ eventFD, err := eventfd.Create()
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("eventfd failed: %v", err)
+ }
+ dataFD, err = createFile(s.dataSize, false)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create dataFD: %s", err)
+ }
+ txPipeFD, err = createFile(s.txPipeSize, true)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create txPipeFD: %s", err)
+ }
+ rxPipeFD, err = createFile(s.rxPipeSize, true)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create rxPipeFD: %s", err)
+ }
+ sharedDataFD, err = createFile(s.sharedDataSize, false)
+ if err != nil {
+ return QueueConfig{}, fmt.Errorf("failed to create sharedDataFD: %s", err)
+ }
+ success = true
+ return QueueConfig{
+ EventFD: eventFD,
+ DataFD: dataFD,
+ TxPipeFD: txPipeFD,
+ RxPipeFD: rxPipeFD,
+ SharedDataFD: sharedDataFD,
+ }, nil
+}
+
+func createFile(size int64, initQueue bool) (fd int, err error) {
+ const tmpDir = "/dev/shm/"
+ f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
+ if err != nil {
+ return -1, fmt.Errorf("TempFile failed: %v", err)
+ }
+ defer f.Close()
+ unix.Unlink(f.Name())
+
+ if initQueue {
+ // Write the "slot-free" flag in the initial queue.
+ if _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0); err != nil {
+ return -1, fmt.Errorf("WriteAt failed: %v", err)
+ }
+ }
+
+ fd, err = unix.Dup(int(f.Fd()))
+ if err != nil {
+ return -1, fmt.Errorf("unix.Dup(%d) failed: %v", f.Fd(), err)
+ }
+
+ if err := unix.Ftruncate(fd, size); err != nil {
+ unix.Close(fd)
+ return -1, fmt.Errorf("ftruncate(%d, %d) failed: %v", fd, size, err)
+ }
+
+ return fd, nil
+}
+
+func closeFDs(c QueueConfig) {
+ unix.Close(c.DataFD)
+ c.EventFD.Close()
+ unix.Close(c.TxPipeFD)
+ unix.Close(c.RxPipeFD)
+ unix.Close(c.SharedDataFD)
+}
diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go
index e882a128c..87747dcc7 100644
--- a/pkg/tcpip/link/sharedmem/rx.go
+++ b/pkg/tcpip/link/sharedmem/rx.go
@@ -21,7 +21,7 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
)
@@ -30,7 +30,7 @@ type rx struct {
data []byte
sharedData []byte
q queue.Rx
- eventFD int
+ eventFD eventfd.Eventfd
}
// init initializes all state needed by the rx queue based on the information
@@ -68,7 +68,7 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error {
// Duplicate the eventFD so that caller can close it but we can still
// use it.
- efd, err := unix.Dup(c.EventFD)
+ efd, err := c.EventFD.Dup()
if err != nil {
unix.Munmap(txPipe)
unix.Munmap(rxPipe)
@@ -77,16 +77,6 @@ func (r *rx) init(mtu uint32, c *QueueConfig) error {
return err
}
- // Set the eventfd as non-blocking.
- if err := unix.SetNonblock(efd, true); err != nil {
- unix.Munmap(txPipe)
- unix.Munmap(rxPipe)
- unix.Munmap(data)
- unix.Munmap(sharedData)
- unix.Close(efd)
- return err
- }
-
// Initialize state based on buffers.
r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData))
r.data = data
@@ -105,7 +95,13 @@ func (r *rx) cleanup() {
unix.Munmap(r.data)
unix.Munmap(r.sharedData)
- unix.Close(r.eventFD)
+ r.eventFD.Close()
+}
+
+// notify writes to the rx.eventFD to indicate to the peer that there is data to
+// be read.
+func (r *rx) notify() {
+ r.eventFD.Notify()
}
// postAndReceive posts the provided buffers (if any), and then tries to read
@@ -122,8 +118,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.
if len(b) != 0 && !r.q.PostBuffers(b) {
r.q.EnableNotification()
for !r.q.PostBuffers(b) {
- var tmp [8]byte
- rawfile.BlockingRead(r.eventFD, tmp[:])
+ r.eventFD.Wait()
if atomic.LoadUint32(stopRequested) != 0 {
r.q.DisableNotification()
return nil, 0
@@ -147,8 +142,7 @@ func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.
}
// Wait for notification.
- var tmp [8]byte
- rawfile.BlockingRead(r.eventFD, tmp[:])
+ r.eventFD.Wait()
if atomic.LoadUint32(stopRequested) != 0 {
r.q.DisableNotification()
return nil, 0
diff --git a/pkg/tcpip/link/sharedmem/server_rx.go b/pkg/tcpip/link/sharedmem/server_rx.go
new file mode 100644
index 000000000..6ea21ffd1
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/server_rx.go
@@ -0,0 +1,142 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/eventfd"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+type serverRx struct {
+ // packetPipe represents the receive end of the pipe that carries the packet
+ // descriptors sent by the client.
+ packetPipe pipe.Rx
+
+ // completionPipe represents the transmit end of the pipe that will carry
+ // completion notifications from the server to the client.
+ completionPipe pipe.Tx
+
+ // data represents the buffer area where the packet payload is held.
+ data []byte
+
+ // eventFD is used to notify the peer when transmission is completed.
+ eventFD eventfd.Eventfd
+
+ // sharedData the memory region to use to enable/disable notifications.
+ sharedData []byte
+}
+
+// init initializes all state needed by the serverTx queue based on the
+// information provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+func (s *serverRx) init(c *QueueConfig) error {
+ // Map in all buffers.
+ packetPipeMem, err := getBuffer(c.TxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu := cleanup.Make(func() { unix.Munmap(packetPipeMem) })
+ defer cu.Clean()
+
+ completionPipeMem, err := getBuffer(c.RxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(completionPipeMem) })
+
+ data, err := getBuffer(c.DataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(data) })
+
+ sharedData, err := getBuffer(c.SharedDataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(sharedData) })
+
+ // Duplicate the eventFD so that caller can close it but we can still
+ // use it.
+ efd, err := c.EventFD.Dup()
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { efd.Close() })
+
+ s.packetPipe.Init(packetPipeMem)
+ s.completionPipe.Init(completionPipeMem)
+ s.data = data
+ s.eventFD = efd
+ s.sharedData = sharedData
+
+ cu.Release()
+ return nil
+}
+
+func (s *serverRx) cleanup() {
+ unix.Munmap(s.packetPipe.Bytes())
+ unix.Munmap(s.completionPipe.Bytes())
+ unix.Munmap(s.data)
+ unix.Munmap(s.sharedData)
+ s.eventFD.Close()
+}
+
+// completionNotificationSize is size in bytes of a completion notification sent
+// on the completion queue after a transmitted packet has been handled.
+const completionNotificationSize = 8
+
+// receive receives a single packet from the packetPipe.
+func (s *serverRx) receive() []byte {
+ desc := s.packetPipe.Pull()
+ if desc == nil {
+ return nil
+ }
+
+ pktInfo := queue.DecodeTxPacketHeader(desc)
+ contents := make([]byte, 0, pktInfo.Size)
+ toCopy := pktInfo.Size
+ for i := 0; i < pktInfo.BufferCount; i++ {
+ txBuf := queue.DecodeTxBufferHeader(desc, i)
+ if txBuf.Size <= toCopy {
+ contents = append(contents, s.data[txBuf.Offset:][:txBuf.Size]...)
+ toCopy -= txBuf.Size
+ continue
+ }
+ contents = append(contents, s.data[txBuf.Offset:][:toCopy]...)
+ break
+ }
+
+ // Flush to let peer know that slots queued for transmission have been handled
+ // and its free to reuse the slots.
+ s.packetPipe.Flush()
+ // Encode packet completion.
+ b := s.completionPipe.Push(completionNotificationSize)
+ queue.EncodeTxCompletion(b, pktInfo.ID)
+ s.completionPipe.Flush()
+ return contents
+}
+
+func (s *serverRx) waitForPackets() {
+ s.eventFD.Wait()
+}
diff --git a/pkg/tcpip/link/sharedmem/server_tx.go b/pkg/tcpip/link/sharedmem/server_tx.go
new file mode 100644
index 000000000..13a82903f
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/server_tx.go
@@ -0,0 +1,175 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/eventfd"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+// serverTx represents the server end of the sharedmem queue and is used to send
+// packets to the peer in the buffers posted by the peer in the fillPipe.
+type serverTx struct {
+ // fillPipe represents the receive end of the pipe that carries the RxBuffers
+ // posted by the peer.
+ fillPipe pipe.Rx
+
+ // completionPipe represents the transmit end of the pipe that carries the
+ // descriptors for filled RxBuffers.
+ completionPipe pipe.Tx
+
+ // data represents the buffer area where the packet payload is held.
+ data []byte
+
+ // eventFD is used to notify the peer when fill requests are fulfilled.
+ eventFD eventfd.Eventfd
+
+ // sharedData the memory region to use to enable/disable notifications.
+ sharedData []byte
+}
+
+// init initializes all state needed by the serverTx queue based on the
+// information provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+func (s *serverTx) init(c *QueueConfig) error {
+ // Map in all buffers.
+ fillPipeMem, err := getBuffer(c.TxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu := cleanup.Make(func() { unix.Munmap(fillPipeMem) })
+ defer cu.Clean()
+
+ completionPipeMem, err := getBuffer(c.RxPipeFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(completionPipeMem) })
+
+ data, err := getBuffer(c.DataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(data) })
+
+ sharedData, err := getBuffer(c.SharedDataFD)
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { unix.Munmap(sharedData) })
+
+ // Duplicate the eventFD so that caller can close it but we can still
+ // use it.
+ efd, err := c.EventFD.Dup()
+ if err != nil {
+ return err
+ }
+ cu.Add(func() { efd.Close() })
+
+ cu.Release()
+
+ s.fillPipe.Init(fillPipeMem)
+ s.completionPipe.Init(completionPipeMem)
+ s.data = data
+ s.eventFD = efd
+ s.sharedData = sharedData
+
+ return nil
+}
+
+func (s *serverTx) cleanup() {
+ unix.Munmap(s.fillPipe.Bytes())
+ unix.Munmap(s.completionPipe.Bytes())
+ unix.Munmap(s.data)
+ unix.Munmap(s.sharedData)
+ s.eventFD.Close()
+}
+
+// fillPacket copies the data in the provided views into buffers pulled from the
+// fillPipe and returns a slice of RxBuffers that contain the copied data as
+// well as the total number of bytes copied.
+//
+// To avoid allocations the filledBuffers are appended to the buffers slice
+// which will be grown as required.
+func (s *serverTx) fillPacket(views []buffer.View, buffers []queue.RxBuffer) (filledBuffers []queue.RxBuffer, totalCopied uint32) {
+ filledBuffers = buffers[:0]
+ // fillBuffer copies as much of the views as possible into the provided buffer
+ // and returns any left over views (if any).
+ fillBuffer := func(buffer *queue.RxBuffer, views []buffer.View) (left []buffer.View) {
+ if len(views) == 0 {
+ return nil
+ }
+ availBytes := buffer.Size
+ copied := uint64(0)
+ for availBytes > 0 && len(views) > 0 {
+ n := copy(s.data[buffer.Offset+copied:][:uint64(buffer.Size)-copied], views[0])
+ views[0].TrimFront(n)
+ if !views[0].IsEmpty() {
+ break
+ }
+ views = views[1:]
+ copied += uint64(n)
+ availBytes -= uint32(n)
+ }
+ buffer.Size = uint32(copied)
+ return views
+ }
+
+ for len(views) > 0 {
+ var b []byte
+ // Spin till we get a free buffer reposted by the peer.
+ for {
+ if b = s.fillPipe.Pull(); b != nil {
+ break
+ }
+ }
+ rxBuffer := queue.DecodeRxBufferHeader(b)
+ // Copy the packet into the posted buffer.
+ views = fillBuffer(&rxBuffer, views)
+ totalCopied += rxBuffer.Size
+ filledBuffers = append(filledBuffers, rxBuffer)
+ }
+
+ return filledBuffers, totalCopied
+}
+
+func (s *serverTx) transmit(views []buffer.View) bool {
+ buffers := make([]queue.RxBuffer, 8)
+ buffers, totalCopied := s.fillPacket(views, buffers)
+ b := s.completionPipe.Push(queue.RxCompletionSize(len(buffers)))
+ if b == nil {
+ return false
+ }
+ queue.EncodeRxCompletion(b, totalCopied, 0 /* reserved */)
+ for i := 0; i < len(buffers); i++ {
+ queue.EncodeRxCompletionBuffer(b, i, buffers[i])
+ }
+ s.completionPipe.Flush()
+ s.fillPipe.Flush()
+ return true
+}
+
+func (s *serverTx) notify() {
+ s.eventFD.Notify()
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 66efe6472..b75522a51 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -24,14 +24,16 @@
package sharedmem
import (
+ "fmt"
"sync/atomic"
- "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -47,7 +49,7 @@ type QueueConfig struct {
// EventFD is a file descriptor for the event that is signaled when
// data is becomes available in this queue.
- EventFD int
+ EventFD eventfd.Eventfd
// TxPipeFD is a file descriptor for the tx pipe associated with the
// queue.
@@ -63,16 +65,97 @@ type QueueConfig struct {
SharedDataFD int
}
+// FDs returns the FDs in the QueueConfig as a slice of ints. This must
+// be used in conjunction with QueueConfigFromFDs to ensure the order
+// of FDs matches when reconstructing the config when serialized or sent
+// as part of control messages.
+func (q *QueueConfig) FDs() []int {
+ return []int{q.DataFD, q.EventFD.FD(), q.TxPipeFD, q.RxPipeFD, q.SharedDataFD}
+}
+
+// QueueConfigFromFDs constructs a QueueConfig out of a slice of ints where each
+// entry represents a file descriptor. The order of FDs in the slice must be in
+// the order specified below for the config to be valid. QueueConfig.FDs()
+// should be used when the config needs to be serialized or sent as part of a
+// control message to ensure the correct order.
+func QueueConfigFromFDs(fds []int) (QueueConfig, error) {
+ if len(fds) != 5 {
+ return QueueConfig{}, fmt.Errorf("insufficient number of fds: len(fds): %d, want: 5", len(fds))
+ }
+ return QueueConfig{
+ DataFD: fds[0],
+ EventFD: eventfd.Wrap(fds[1]),
+ TxPipeFD: fds[2],
+ RxPipeFD: fds[3],
+ SharedDataFD: fds[4],
+ }, nil
+}
+
+// Options specify the details about the sharedmem endpoint to be created.
+type Options struct {
+ // MTU is the mtu to use for this endpoint.
+ MTU uint32
+
+ // BufferSize is the size of each scatter/gather buffer that will hold packet
+ // data.
+ //
+ // NOTE: This directly determines number of packets that can be held in
+ // the ring buffer at any time. This does not have to be sized to the MTU as
+ // the shared memory queue design allows usage of more than one buffer to be
+ // used to make up a given packet.
+ BufferSize uint32
+
+ // LinkAddress is the link address for this endpoint (required).
+ LinkAddress tcpip.LinkAddress
+
+ // TX is the transmit queue configuration for this shared memory endpoint.
+ TX QueueConfig
+
+ // RX is the receive queue configuration for this shared memory endpoint.
+ RX QueueConfig
+
+ // PeerFD is the fd for the connected peer which can be used to detect
+ // peer disconnects.
+ PeerFD int
+
+ // OnClosed is a function that is called when the endpoint is being closed
+ // (probably due to peer going away)
+ OnClosed func(err tcpip.Error)
+
+ // TXChecksumOffload if true, indicates that this endpoints capability
+ // set should include CapabilityTXChecksumOffload.
+ TXChecksumOffload bool
+
+ // RXChecksumOffload if true, indicates that this endpoints capability
+ // set should include CapabilityRXChecksumOffload.
+ RXChecksumOffload bool
+}
+
type endpoint struct {
// mtu (maximum transmission unit) is the maximum size of a packet.
+ // mtu is immutable.
mtu uint32
// bufferSize is the size of each individual buffer.
+ // bufferSize is immutable.
bufferSize uint32
// addr is the local address of this endpoint.
+ // addr is immutable.
addr tcpip.LinkAddress
+ // peerFD is an fd to the peer that can be used to detect when the
+ // peer is gone.
+ // peerFD is immutable.
+ peerFD int
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // hdrSize is the size of the link layer header if any.
+ // hdrSize is immutable.
+ hdrSize uint32
+
// rx is the receive queue.
rx rx
@@ -83,34 +166,55 @@ type endpoint struct {
// Wait group used to indicate that all workers have stopped.
completed sync.WaitGroup
+ // onClosed is a function to be called when the FD's peer (if any) closes
+ // its end of the communication pipe.
+ onClosed func(tcpip.Error)
+
// mu protects the following fields.
mu sync.Mutex
// tx is the transmit queue.
+ // +checklocks:mu
tx tx
// workerStarted specifies whether the worker goroutine was started.
+ // +checklocks:mu
workerStarted bool
}
// New creates a new shared-memory-based endpoint. Buffers will be broken up
// into buffers of "bufferSize" bytes.
-func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) {
+func New(opts Options) (stack.LinkEndpoint, error) {
e := &endpoint{
- mtu: mtu,
- bufferSize: bufferSize,
- addr: addr,
+ mtu: opts.MTU,
+ bufferSize: opts.BufferSize,
+ addr: opts.LinkAddress,
+ peerFD: opts.PeerFD,
+ onClosed: opts.OnClosed,
}
- if err := e.tx.init(bufferSize, &tx); err != nil {
+ if err := e.tx.init(opts.BufferSize, &opts.TX); err != nil {
return nil, err
}
- if err := e.rx.init(bufferSize, &rx); err != nil {
+ if err := e.rx.init(opts.BufferSize, &opts.RX); err != nil {
e.tx.cleanup()
return nil, err
}
+ e.caps = stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ e.caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ e.caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ if opts.LinkAddress != "" {
+ e.hdrSize = header.EthernetMinimumSize
+ e.caps |= stack.CapabilityResolutionRequired
+ }
return e, nil
}
@@ -119,13 +223,13 @@ func (e *endpoint) Close() {
// Tell dispatch goroutine to stop, then write to the eventfd so that
// it wakes up in case it's sleeping.
atomic.StoreUint32(&e.stopRequested, 1)
- unix.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ e.rx.eventFD.Notify()
// Cleanup the queues inline if the worker hasn't started yet; we also
// know it won't start from now on because stopRequested is set to 1.
e.mu.Lock()
+ defer e.mu.Unlock()
workerPresent := e.workerStarted
- e.mu.Unlock()
if !workerPresent {
e.tx.cleanup()
@@ -146,6 +250,22 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 {
e.workerStarted = true
e.completed.Add(1)
+
+ // Spin up a goroutine to monitor for peer shutdown.
+ if e.peerFD >= 0 {
+ e.completed.Add(1)
+ go func() {
+ defer e.completed.Done()
+ b := make([]byte, 1)
+ // When sharedmem endpoint is in use the peerFD is never used for any data
+ // transfer and this Read should only return if the peer is shutting down.
+ _, err := rawfile.BlockingRead(e.peerFD, b)
+ if e.onClosed != nil {
+ e.onClosed(err)
+ }
+ }()
+ }
+
// Link endpoints are not savable. When transportation endpoints
// are saved, they stop sending outgoing packets and all
// incoming packets are rejected.
@@ -164,18 +284,18 @@ func (e *endpoint) IsAttached() bool {
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
// during construction.
func (e *endpoint) MTU() uint32 {
- return e.mtu - header.EthernetMinimumSize
+ return e.mtu - e.hdrSize
}
// Capabilities implements stack.LinkEndpoint.Capabilities.
-func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return 0
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
}
// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
// ethernet frame header size.
-func (*endpoint) MaxHeaderLength() uint16 {
- return header.EthernetMinimumSize
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
}
// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
@@ -205,17 +325,15 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WriteRawPacket implements stack.LinkEndpoint.
func (*endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} }
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+// +checklocks:e.mu
+func (e *endpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ if e.addr != "" {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
views := pkt.Views()
// Transmit the packet.
- e.mu.Lock()
ok := e.tx.transmit(views...)
- e.mu.Unlock()
-
if !ok {
return &tcpip.ErrWouldBlock{}
}
@@ -223,9 +341,37 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
return nil
}
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
+ return err
+ }
+ e.tx.notify()
+ return nil
+}
+
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- panic("not implemented")
+func (e *endpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ n := 0
+ var err tcpip.Error
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
+ break
+ }
+ n++
+ }
+ // WritePackets never returns an error if it successfully transmitted at least
+ // one packet.
+ if err != nil && n == 0 {
+ return 0, err
+ }
+ e.tx.notify()
+ return n, nil
}
// dispatchLoop reads packets from the rx queue in a loop and dispatches them
@@ -268,16 +414,42 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
Data: buffer.View(b).ToVectorisedView(),
})
- hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
- if !ok {
- continue
+ var src, dst tcpip.LinkAddress
+ var proto tcpip.NetworkProtocolNumber
+ if e.addr != "" {
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ continue
+ }
+ eth := header.Ethernet(hdr)
+ src = eth.SourceAddress()
+ dst = eth.DestinationAddress()
+ proto = eth.Type()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ // IP version information is at the first octet, so pulling up 1 byte.
+ h, ok := pkt.Data().PullUp(1)
+ if !ok {
+ continue
+ }
+ switch header.IPVersion(h) {
+ case header.IPv4Version:
+ proto = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ proto = header.IPv6ProtocolNumber
+ default:
+ continue
+ }
}
- eth := header.Ethernet(hdr)
// Send packet up the stack.
- d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt)
+ d.DeliverNetworkPacket(src, dst, proto, pkt)
}
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
// Clean state.
e.tx.cleanup()
e.rx.cleanup()
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server.go b/pkg/tcpip/link/sharedmem/sharedmem_server.go
new file mode 100644
index 000000000..43c5b8c63
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_server.go
@@ -0,0 +1,344 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type serverEndpoint struct {
+ // mtu (maximum transmission unit) is the maximum size of a packet.
+ // mtu is immutable.
+ mtu uint32
+
+ // bufferSize is the size of each individual buffer.
+ // bufferSize is immutable.
+ bufferSize uint32
+
+ // addr is the local address of this endpoint.
+	// addr is immutable.
+ addr tcpip.LinkAddress
+
+ // rx is the receive queue.
+ rx serverRx
+
+ // stopRequested is to be accessed atomically only, and determines if the
+ // worker goroutines should stop.
+ stopRequested uint32
+
+ // Wait group used to indicate that all workers have stopped.
+ completed sync.WaitGroup
+
+ // peerFD is an fd to the peer that can be used to detect when the peer is
+ // gone.
+ // peerFD is immutable.
+ peerFD int
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // hdrSize is the size of the link layer header if any.
+ // hdrSize is immutable.
+ hdrSize uint32
+
+ // onClosed is a function to be called when the FD's peer (if any) closes its
+ // end of the communication pipe.
+ onClosed func(tcpip.Error)
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // tx is the transmit queue.
+ // +checklocks:mu
+ tx serverTx
+
+ // workerStarted specifies whether the worker goroutine was started.
+ // +checklocks:mu
+ workerStarted bool
+}
+
+// NewServerEndpoint creates a new shared-memory-based endpoint. Buffers will be
+// broken up into buffers of "bufferSize" bytes.
+func NewServerEndpoint(opts Options) (stack.LinkEndpoint, error) {
+ e := &serverEndpoint{
+ mtu: opts.MTU,
+ bufferSize: opts.BufferSize,
+ addr: opts.LinkAddress,
+ peerFD: opts.PeerFD,
+ onClosed: opts.OnClosed,
+ }
+
+ if err := e.tx.init(&opts.RX); err != nil {
+ return nil, err
+ }
+
+ if err := e.rx.init(&opts.TX); err != nil {
+ e.tx.cleanup()
+ return nil, err
+ }
+
+ e.caps = stack.LinkEndpointCapabilities(0)
+ if opts.RXChecksumOffload {
+ e.caps |= stack.CapabilityRXChecksumOffload
+ }
+
+ if opts.TXChecksumOffload {
+ e.caps |= stack.CapabilityTXChecksumOffload
+ }
+
+ if opts.LinkAddress != "" {
+ e.hdrSize = header.EthernetMinimumSize
+ e.caps |= stack.CapabilityResolutionRequired
+ }
+
+ return e, nil
+}
+
+// Close frees all resources associated with the endpoint.
+func (e *serverEndpoint) Close() {
+ // Tell dispatch goroutine to stop, then write to the eventfd so that it wakes
+ // up in case it's sleeping.
+ atomic.StoreUint32(&e.stopRequested, 1)
+ e.rx.eventFD.Notify()
+
+ // Cleanup the queues inline if the worker hasn't started yet; we also know it
+ // won't start from now on because stopRequested is set to 1.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ workerPresent := e.workerStarted
+
+ if !workerPresent {
+ e.tx.cleanup()
+ e.rx.cleanup()
+ }
+}
+
+// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
+// stopped after a Close() call.
+func (e *serverEndpoint) Wait() {
+ e.completed.Wait()
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that
+// reads packets from the rx queue.
+func (e *serverEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.mu.Lock()
+ if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 {
+ e.workerStarted = true
+ e.completed.Add(1)
+ if e.peerFD >= 0 {
+ e.completed.Add(1)
+ // Spin up a goroutine to monitor for peer shutdown.
+ go func() {
+ b := make([]byte, 1)
+ // When sharedmem endpoint is in use the peerFD is never used for any
+ // data transfer and this Read should only return if the peer is
+ // shutting down.
+ _, err := rawfile.BlockingRead(e.peerFD, b)
+ if e.onClosed != nil {
+ e.onClosed(err)
+ }
+ e.completed.Done()
+ }()
+ }
+ // Link endpoints are not savable. When transportation endpoints are saved,
+ // they stop sending outgoing packets and all incoming packets are rejected.
+ go e.dispatchLoop(dispatcher) // S/R-SAFE: see above.
+ }
+ e.mu.Unlock()
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *serverEndpoint) IsAttached() bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.workerStarted
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *serverEndpoint) MTU() uint32 {
+ return e.mtu - e.hdrSize
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *serverEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
+// ethernet frame header size.
+func (e *serverEndpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
+// link address.
+func (e *serverEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.addr
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *serverEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ // Add ethernet header if needed.
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
+ ethHdr := &header.EthernetFields{
+ DstAddr: remote,
+ Type: protocol,
+ }
+
+ // Preserve the src address if it's set in the route.
+ if local != "" {
+ ethHdr.SrcAddr = local
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket
+func (e *serverEndpoint) WriteRawPacket(pkt *stack.PacketBuffer) tcpip.Error {
+ views := pkt.Views()
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ ok := e.tx.transmit(views)
+ if !ok {
+ return &tcpip.ErrWouldBlock{}
+ }
+ e.tx.notify()
+ return nil
+}
+
+// +checklocks:e.mu
+func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ if e.addr != "" {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
+
+ views := pkt.Views()
+ ok := e.tx.transmit(views)
+ if !ok {
+ return &tcpip.ErrWouldBlock{}
+ }
+
+ return nil
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+// WritePacket implements stack.LinkEndpoint.WritePacket.
+func (e *serverEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ // Transmit the packet.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
+ return err
+ }
+ e.tx.notify()
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *serverEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ n := 0
+ var err tcpip.Error
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
+ break
+ }
+ n++
+ }
+ // WritePackets never returns an error if it successfully transmitted at least
+ // one packet.
+ if err != nil && n == 0 {
+ return 0, err
+ }
+ e.tx.notify()
+ return n, nil
+}
+
+// dispatchLoop reads packets from the rx queue in a loop and dispatches them
+// to the network stack.
+func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) {
+ for atomic.LoadUint32(&e.stopRequested) == 0 {
+ b := e.rx.receive()
+ if b == nil {
+ e.rx.waitForPackets()
+ continue
+ }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.View(b).ToVectorisedView(),
+ })
+ var src, dst tcpip.LinkAddress
+ var proto tcpip.NetworkProtocolNumber
+ if e.addr != "" {
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ continue
+ }
+ eth := header.Ethernet(hdr)
+ src = eth.SourceAddress()
+ dst = eth.DestinationAddress()
+ proto = eth.Type()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ // IP version information is at the first octet, so pulling up 1 byte.
+ h, ok := pkt.Data().PullUp(1)
+ if !ok {
+ continue
+ }
+ switch header.IPVersion(h) {
+ case header.IPv4Version:
+ proto = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ proto = header.IPv6ProtocolNumber
+ default:
+ continue
+ }
+ }
+ // Send packet up the stack.
+ d.DeliverNetworkPacket(src, dst, proto, pkt)
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Clean state.
+ e.tx.cleanup()
+ e.rx.cleanup()
+
+ e.completed.Done()
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
+func (e *serverEndpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.hdrSize > 0 {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_server_test.go b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go
new file mode 100644
index 000000000..1bc58614e
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_server_test.go
@@ -0,0 +1,220 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package sharedmem_server_test
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "syscall"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+const (
+ localLinkAddr = "\xde\xad\xbe\xef\x56\x78"
+ remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34"
+ localIPv4Address = tcpip.Address("\x0a\x00\x00\x01")
+ remoteIPv4Address = tcpip.Address("\x0a\x00\x00\x02")
+ serverPort = 10001
+
+ defaultMTU = 1500
+ defaultBufferSize = 1500
+)
+
+type stackOptions struct {
+ ep stack.LinkEndpoint
+ addr tcpip.Address
+}
+
+func newStackWithOptions(stackOpts stackOptions) (*stack.Stack, error) {
+ st := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocolWithOptions(ipv4.Options{
+ AllowExternalLoopbackTraffic: true,
+ }),
+ ipv6.NewProtocolWithOptions(ipv6.Options{
+ AllowExternalLoopbackTraffic: true,
+ }),
+ },
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
+ })
+ nicID := tcpip.NICID(1)
+ sniffEP := sniffer.New(stackOpts.ep)
+ opts := stack.NICOptions{Name: "eth0"}
+ if err := st.CreateNICWithOptions(nicID, sniffEP, opts); err != nil {
+ return nil, fmt.Errorf("method CreateNICWithOptions(%d, _, %v) failed: %s", nicID, opts, err)
+ }
+
+ // Add Protocol Address.
+ protocolNum := ipv4.ProtocolNumber
+ routeTable := []tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}
+ if len(stackOpts.addr) == 16 {
+ routeTable = []tcpip.Route{{Destination: header.IPv6EmptySubnet, NIC: nicID}}
+ protocolNum = ipv6.ProtocolNumber
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: protocolNum,
+ AddressWithPrefix: stackOpts.addr.WithPrefix(),
+ }
+ if err := st.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, fmt.Errorf("AddProtocolAddress(%d, %v, {}): %s", nicID, protocolAddr, err)
+ }
+
+ // Setup route table.
+ st.SetRouteTable(routeTable)
+
+ return st, nil
+}
+
+func newClientStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) {
+ ep, err := sharedmem.New(sharedmem.Options{
+ MTU: defaultMTU,
+ BufferSize: defaultBufferSize,
+ LinkAddress: localLinkAddr,
+ TX: qPair.TXQueueConfig(),
+ RX: qPair.RXQueueConfig(),
+ PeerFD: peerFD,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
+ }
+ st, err := newStackWithOptions(stackOptions{ep: ep, addr: localIPv4Address})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create client stack: %s", err)
+ }
+ return st, nil
+}
+
+func newServerStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) {
+ ep, err := sharedmem.NewServerEndpoint(sharedmem.Options{
+ MTU: defaultMTU,
+ BufferSize: defaultBufferSize,
+ LinkAddress: remoteLinkAddr,
+ TX: qPair.TXQueueConfig(),
+ RX: qPair.RXQueueConfig(),
+ PeerFD: peerFD,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
+ }
+ st, err := newStackWithOptions(stackOptions{ep: ep, addr: remoteIPv4Address})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create client stack: %s", err)
+ }
+ return st, nil
+}
+
+type testContext struct {
+ clientStk *stack.Stack
+ serverStk *stack.Stack
+ peerFDs [2]int
+}
+
+func newTestContext(t *testing.T) *testContext {
+ peerFDs, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET|syscall.SOCK_NONBLOCK, 0)
+ if err != nil {
+ t.Fatalf("failed to create peerFDs: %s", err)
+ }
+ q, err := sharedmem.NewQueuePair()
+ if err != nil {
+ t.Fatalf("failed to create sharedmem queue: %s", err)
+ }
+ clientStack, err := newClientStack(t, q, peerFDs[0])
+ if err != nil {
+ q.Close()
+ unix.Close(peerFDs[0])
+ unix.Close(peerFDs[1])
+ t.Fatalf("failed to create client stack: %s", err)
+ }
+ serverStack, err := newServerStack(t, q, peerFDs[1])
+ if err != nil {
+ q.Close()
+ unix.Close(peerFDs[0])
+ unix.Close(peerFDs[1])
+ clientStack.Close()
+ t.Fatalf("failed to create server stack: %s", err)
+ }
+ return &testContext{
+ clientStk: clientStack,
+ serverStk: serverStack,
+ peerFDs: peerFDs,
+ }
+}
+
+func (ctx *testContext) cleanup() {
+ unix.Close(ctx.peerFDs[0])
+ unix.Close(ctx.peerFDs[1])
+ ctx.clientStk.Close()
+ ctx.serverStk.Close()
+}
+
+func TestServerRoundTrip(t *testing.T) {
+ ctx := newTestContext(t)
+ defer ctx.cleanup()
+ listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort}
+ l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatalf("failed to start TCP Listener: %s", err)
+ }
+ defer l.Close()
+ var responseString = "response"
+ go func() {
+ http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(responseString))
+ }))
+ }()
+
+ dialFunc := func(address, protocol string) (net.Conn, error) {
+ return gonet.DialTCP(ctx.clientStk, listenAddr, ipv4.ProtocolNumber)
+ }
+
+ httpClient := &http.Client{
+ Transport: &http.Transport{
+ Dial: dialFunc,
+ },
+ }
+ serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Address), serverPort)
+ response, err := httpClient.Get(serverURL)
+ if err != nil {
+ t.Fatalf("httpClient.Get(\"/\") failed: %s", err)
+ }
+ if got, want := response.StatusCode, http.StatusOK; got != want {
+ t.Fatalf("unexpected status code got: %d, want: %d", got, want)
+ }
+ body, err := io.ReadAll(response.Body)
+ if err != nil {
+ t.Fatalf("io.ReadAll(response.Body) failed: %s", err)
+ }
+ response.Body.Close()
+ if got, want := string(body), responseString; got != want {
+ t.Fatalf("unexpected response got: %s, want: %s", got, want)
+ }
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index d6d953085..a49f5f87d 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -19,9 +19,7 @@ package sharedmem
import (
"bytes"
- "io/ioutil"
"math/rand"
- "os"
"strings"
"testing"
"time"
@@ -104,24 +102,36 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
t: t,
packetCh: make(chan struct{}, 1000000),
}
- c.txCfg = createQueueFDs(t, queueSizes{
+ c.txCfg, err = createQueueFDs(queueSizes{
dataSize: queueDataSize,
txPipeSize: queuePipeSize,
rxPipeSize: queuePipeSize,
sharedDataSize: 4096,
})
-
- c.rxCfg = createQueueFDs(t, queueSizes{
+ if err != nil {
+ t.Fatalf("createQueueFDs for tx failed: %s", err)
+ }
+ c.rxCfg, err = createQueueFDs(queueSizes{
dataSize: queueDataSize,
txPipeSize: queuePipeSize,
rxPipeSize: queuePipeSize,
sharedDataSize: 4096,
})
+ if err != nil {
+ t.Fatalf("createQueueFDs for rx failed: %s", err)
+ }
initQueue(t, &c.txq, &c.txCfg)
initQueue(t, &c.rxq, &c.rxCfg)
- ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
+ ep, err := New(Options{
+ MTU: mtu,
+ BufferSize: bufferSize,
+ LinkAddress: addr,
+ TX: c.txCfg,
+ RX: c.rxCfg,
+ PeerFD: -1,
+ })
if err != nil {
t.Fatalf("New failed: %v", err)
}
@@ -150,8 +160,8 @@ func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.
func (c *testContext) cleanup() {
c.ep.Close()
- closeFDs(&c.txCfg)
- closeFDs(&c.rxCfg)
+ closeFDs(c.txCfg)
+ closeFDs(c.rxCfg)
c.txq.cleanup()
c.rxq.cleanup()
}
@@ -191,69 +201,6 @@ func shuffle(b []int) {
}
}
-func createFile(t *testing.T, size int64, initQueue bool) int {
- tmpDir, ok := os.LookupEnv("TEST_TMPDIR")
- if !ok {
- tmpDir = os.Getenv("TMPDIR")
- }
- f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
- if err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
- defer f.Close()
- unix.Unlink(f.Name())
-
- if initQueue {
- // Write the "slot-free" flag in the initial queue.
- _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0)
- if err != nil {
- t.Fatalf("WriteAt failed: %v", err)
- }
- }
-
- fd, err := unix.Dup(int(f.Fd()))
- if err != nil {
- t.Fatalf("Dup failed: %v", err)
- }
-
- if err := unix.Ftruncate(fd, size); err != nil {
- unix.Close(fd)
- t.Fatalf("Ftruncate failed: %v", err)
- }
-
- return fd
-}
-
-func closeFDs(c *QueueConfig) {
- unix.Close(c.DataFD)
- unix.Close(c.EventFD)
- unix.Close(c.TxPipeFD)
- unix.Close(c.RxPipeFD)
- unix.Close(c.SharedDataFD)
-}
-
-type queueSizes struct {
- dataSize int64
- txPipeSize int64
- rxPipeSize int64
- sharedDataSize int64
-}
-
-func createQueueFDs(t *testing.T, s queueSizes) QueueConfig {
- fd, _, err := unix.RawSyscall(unix.SYS_EVENTFD2, 0, 0, 0)
- if err != 0 {
- t.Fatalf("eventfd failed: %v", error(err))
- }
-
- return QueueConfig{
- EventFD: int(fd),
- DataFD: createFile(t, s.dataSize, false),
- TxPipeFD: createFile(t, s.txPipeSize, true),
- RxPipeFD: createFile(t, s.rxPipeSize, true),
- SharedDataFD: createFile(t, s.sharedDataSize, false),
- }
-}
-
// TestSimpleSend sends 1000 packets with random header and payload sizes,
// then checks that the right payload is received on the shared memory queues.
func TestSimpleSend(t *testing.T) {
@@ -263,6 +210,7 @@ func TestSimpleSend(t *testing.T) {
// Prepare route.
var r stack.RouteInfo
r.RemoteLinkAddress = remoteLinkAddr
+ r.LocalLinkAddress = localLinkAddr
for iters := 1000; iters > 0; iters-- {
func() {
@@ -280,8 +228,11 @@ func TestSimpleSend(t *testing.T) {
Data: data.ToVectorisedView(),
})
copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf)
-
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
+ // Every PacketBuffer must have these set:
+ // See nic.writePacket.
+ pkt.EgressRoute = r
+ pkt.NetworkProtocolNumber = proto
if err := c.ep.WritePacket(r, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -350,8 +301,11 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
// the minimum size of the ethernet header.
ReserveHeaderBytes: header.EthernetMinimumSize,
})
-
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
+ // Every PacketBuffer must have these set:
+ // See nic.writePacket.
+ pkt.EgressRoute = r
+ pkt.NetworkProtocolNumber = proto
if err := c.ep.WritePacket(r, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -672,7 +626,7 @@ func TestSimpleReceive(t *testing.T) {
// Push completion.
c.pushRxCompletion(uint32(len(contents)), bufs)
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for packet to be received, then check it.
c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
@@ -718,7 +672,7 @@ func TestRxBuffersReposted(t *testing.T) {
// Complete the buffer.
c.pushRxCompletion(buffers[i].Size, buffers[i:][:1])
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for it to be reposted.
bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
@@ -734,7 +688,7 @@ func TestRxBuffersReposted(t *testing.T) {
// Complete with two buffers.
c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2])
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for them to be reposted.
for j := 0; j < 2; j++ {
@@ -759,7 +713,7 @@ func TestReceivePostingIsFull(t *testing.T) {
first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted"))
c.pushRxCompletion(first.Size, []queue.RxBuffer{first})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that packet is received.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
@@ -768,7 +722,7 @@ func TestReceivePostingIsFull(t *testing.T) {
second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted"))
c.pushRxCompletion(second.Size, []queue.RxBuffer{second})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that no packet is received yet, as the worker is blocked trying
// to repost.
@@ -781,7 +735,7 @@ func TestReceivePostingIsFull(t *testing.T) {
// Flush tx queue, which will allow the first buffer to be reposted,
// and the second completion to be pulled.
c.rxq.tx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Check that second packet completes.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet")
@@ -803,7 +757,7 @@ func TestCloseWhileWaitingToPost(t *testing.T) {
bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted"))
c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi})
c.rxq.rx.Flush()
- unix.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ c.rxCfg.EventFD.Notify()
// Wait for packet to be indicated.
c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
index f7e816a41..d974c266e 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
@@ -15,7 +15,12 @@
package sharedmem
import (
+ "fmt"
+ "reflect"
"unsafe"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/memutil"
)
// sharedDataPointer converts the shared data slice into a pointer so that it
@@ -23,3 +28,31 @@ import (
func sharedDataPointer(sharedData []byte) *uint32 {
return (*uint32)(unsafe.Pointer(&sharedData[0:4][0]))
}
+
+// getBuffer returns a memory region mapped to the full contents of the given
+// file descriptor.
+func getBuffer(fd int) ([]byte, error) {
+ var s unix.Stat_t
+ if err := unix.Fstat(fd, &s); err != nil {
+ return nil, err
+ }
+
+ // Check that size doesn't overflow an int.
+ if s.Size > int64(^uint(0)>>1) {
+ return nil, unix.EDOM
+ }
+
+ addr, err := memutil.MapFile(0 /* addr */, uintptr(s.Size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_FILE, uintptr(fd), 0 /*offset*/)
+ if err != nil {
+ return nil, fmt.Errorf("failed to map memory for buffer fd: %d, error: %s", fd, err)
+ }
+
+	// Use unsafe to convert addr into a []byte.
+ var b []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b))
+ hdr.Data = addr
+ hdr.Len = int(s.Size)
+ hdr.Cap = int(s.Size)
+
+ return b, nil
+}
diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go
index e3210051f..d6c61afee 100644
--- a/pkg/tcpip/link/sharedmem/tx.go
+++ b/pkg/tcpip/link/sharedmem/tx.go
@@ -18,6 +18,7 @@ import (
"math"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
)
@@ -28,10 +29,12 @@ const (
// tx holds all state associated with a tx queue.
type tx struct {
- data []byte
- q queue.Tx
- ids idManager
- bufs bufferManager
+ data []byte
+ q queue.Tx
+ ids idManager
+ bufs bufferManager
+ eventFD eventfd.Eventfd
+ sharedDataFD int
}
// init initializes all state needed by the tx queue based on the information
@@ -64,7 +67,8 @@ func (t *tx) init(mtu uint32, c *QueueConfig) error {
t.ids.init()
t.bufs.init(0, len(data), int(mtu))
t.data = data
-
+ t.eventFD = c.EventFD
+ t.sharedDataFD = c.SharedDataFD
return nil
}
@@ -142,20 +146,10 @@ func (t *tx) transmit(bufs ...buffer.View) bool {
return true
}
-// getBuffer returns a memory region mapped to the full contents of the given
-// file descriptor.
-func getBuffer(fd int) ([]byte, error) {
- var s unix.Stat_t
- if err := unix.Fstat(fd, &s); err != nil {
- return nil, err
- }
-
- // Check that size doesn't overflow an int.
- if s.Size > int64(^uint(0)>>1) {
- return nil, unix.EDOM
- }
-
- return unix.Mmap(fd, 0, int(s.Size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_FILE)
+// notify writes to the tx.eventFD to indicate to the peer that there is data to
+// be read.
+func (t *tx) notify() {
+ t.eventFD.Notify()
}
// idDescriptor is used by idManager to either point to a tx buffer (in case
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 6515c31e5..e08243547 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -272,7 +272,6 @@ type protocol struct {
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
-func (p *protocol) DefaultPrefixLen() int { return 0 }
func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) {
return "", ""
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 5fcbfeaa2..061cc35ae 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -153,8 +153,12 @@ func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext
t.Fatalf("CreateNIC failed: %s", err)
}
- if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress for ipv4 failed: %s", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: stackAddr.WithPrefix(),
+ }
+ if err := tc.s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
tc.s.SetRouteTable([]tcpip.Route{{
@@ -569,8 +573,12 @@ func TestLinkAddressRequest(t *testing.T) {
}
if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: test.nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 2179302d3..87f650661 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -233,7 +233,13 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv4.ProtocolNumber, local)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: local.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, err
+ }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
Gateway: ipv4Gateway,
@@ -249,7 +255,13 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
s.CreateNIC(nicID, loopback.New())
- s.AddAddress(nicID, ipv6.ProtocolNumber, local)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: local.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, err
+ }
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
Gateway: ipv6Gateway,
@@ -272,13 +284,13 @@ func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *c
}
v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v4Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err)
+ if err := s.AddProtocolAddress(nicID, v4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v4Addr, err)
}
v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
- if err := s.AddProtocolAddress(nicID, v6Addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err)
+ if err := s.AddProtocolAddress(nicID, v6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, v6Addr, err)
}
return s, e
@@ -713,8 +725,8 @@ func TestReceive(t *testing.T) {
if !ok {
t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum)
}
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", test.epAddr, err)
} else {
ep.DecRef()
}
@@ -885,8 +897,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -971,8 +983,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv4Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -1237,8 +1249,8 @@ func TestIPv6ReceiveControl(t *testing.T) {
t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint")
}
addr := localIPv6Addr.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -1304,7 +1316,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name string
protoFactory stack.NetworkProtocolFactory
protoNum tcpip.NetworkProtocolNumber
- nicAddr tcpip.Address
+ nicAddr tcpip.AddressWithPrefix
remoteAddr tcpip.Address
pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
@@ -1314,7 +1326,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
@@ -1355,7 +1367,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with IHL too small",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
@@ -1379,7 +1391,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 too small",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
@@ -1397,7 +1409,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 minimum size",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
@@ -1433,7 +1445,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with options",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length())
@@ -1478,7 +1490,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv4 with options and data across views",
protoFactory: ipv4.NewProtocol,
protoNum: ipv4.ProtocolNumber,
- nicAddr: localIPv4Addr,
+ nicAddr: localIPv4AddrWithPrefix,
remoteAddr: remoteIPv4Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length()))
@@ -1519,7 +1531,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(data)
@@ -1559,7 +1571,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 with extension header",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
@@ -1604,7 +1616,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 minimum size",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
@@ -1639,7 +1651,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
name: "IPv6 too small",
protoFactory: ipv6.NewProtocol,
protoNum: ipv6.ProtocolNumber,
- nicAddr: localIPv6Addr,
+ nicAddr: localIPv6AddrWithPrefix,
remoteAddr: remoteIPv6Addr,
pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
@@ -1663,11 +1675,11 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}{
{
name: "unspecified source",
- srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))),
+ srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr.Address))),
},
{
name: "random source",
- srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))),
+ srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr.Address))),
},
}
@@ -1680,15 +1692,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.protoNum,
+ AddressWithPrefix: test.nicAddr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
- r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */)
+ r, err := s.FindRoute(nicID, test.nicAddr.Address, test.remoteAddr, test.protoNum, false /* multicastLoop */)
if err != nil {
- t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err)
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr.Address, test.protoNum, err)
}
defer r.Release()
@@ -2072,8 +2088,12 @@ func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddressWithPrefix(nicID, test.proto, test.addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, test.proto, test.addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.proto,
+ AddressWithPrefix: test.addr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 2aa38eb98..3eff0bbd8 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -167,23 +167,22 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
p := hdr.TransportProtocol()
dstAddr := hdr.DestinationAddress()
// Skip the ip header, then deliver the error.
- pkt.Data().DeleteFront(hlen)
+ if _, ok := pkt.Data().Consume(hlen); !ok {
+ panic(fmt.Sprintf("could not consume the IP header of %d bytes", hlen))
+ }
e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt)
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
received := e.stats.icmp.packetsReceived
- // ICMP packets don't have their TransportHeader fields set. See
- // icmp/protocol.go:protocol.Parse for a full explanation.
- v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize)
- if !ok {
+ h := header.ICMPv4(pkt.TransportHeader().View())
+ if len(h) < header.ICMPv4MinimumSize {
received.invalid.Increment()
return
}
- h := header.ICMPv4(v)
// Only do in-stack processing if the checksum is correct.
- if pkt.Data().AsRange().Checksum() != 0xffff {
+ if header.Checksum(h, pkt.Data().AsRange().Checksum()) != 0xffff {
received.invalid.Increment()
// It's possible that a raw socket expects to receive this regardless
// of checksum errors. If it's an echo request we know it's safe because
@@ -240,20 +239,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4Echo:
received.echoRequest.Increment()
- sent := e.stats.icmp.packetsSent
- if !e.protocol.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return
- }
-
// DeliverTransportPacket will take ownership of pkt so don't use it beyond
// this point. Make a deep copy of the data before pkt gets sent as we will
- // be modifying fields.
+ // be modifying fields. Both the ICMP header (with its type modified to
+ // EchoReply) and payload are reused in the reply packet.
//
// TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no
// waiting endpoints. Consider moving responsibility for doing the copy to
// DeliverTransportPacket so that is is only done when needed.
- replyData := pkt.Data().AsRange().ToOwnedView()
+ replyData := stack.PayloadSince(pkt.TransportHeader())
ipHdr := header.IPv4(pkt.NetworkHeader().View())
localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast
@@ -281,6 +275,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
}
defer r.Release()
+ sent := e.stats.icmp.packetsSent
+ if !e.protocol.allowICMPReply(header.ICMPv4EchoReply, header.ICMPv4UnusedCode) {
+ sent.rateLimited.Increment()
+ return
+ }
+
// TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the
// header information, we may have to change this code to handle the
// ICMP header no longer being in the data buffer.
@@ -331,6 +331,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
case header.ICMPv4EchoReply:
received.echoReply.Increment()
+ // ICMP sockets expect the ICMP header to be present, so we don't consume
+ // the ICMP header.
e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
@@ -338,7 +340,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
mtu := h.MTU()
code := h.Code()
- pkt.Data().DeleteFront(header.ICMPv4MinimumSize)
switch code {
case header.ICMPv4HostUnreachable:
e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt)
@@ -562,31 +563,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return &tcpip.ErrNotConnected{}
}
- sent := netEP.stats.icmp.packetsSent
-
- if !p.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return nil
- }
-
transportHeader := pkt.TransportHeader().View()
// Don't respond to icmp error packets.
if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) {
- // TODO(gvisor.dev/issue/3810):
- // Unfortunately the current stack pretty much always has ICMPv4 headers
- // in the Data section of the packet but there is no guarantee that is the
- // case. If this is the case grab the header to make it like all other
- // packet types. When this is cleaned up the Consume should be removed.
- if transportHeader.IsEmpty() {
- var ok bool
- transportHeader, ok = pkt.TransportHeader().Consume(header.ICMPv4MinimumSize)
- if !ok {
- return nil
- }
- } else if transportHeader.Size() < header.ICMPv4MinimumSize {
- return nil
- }
// We need to decide to explicitly name the packets we can respond to or
// the ones we can not respond to. The decision is somewhat arbitrary and
// if problems arise this could be reversed. It was judged less of a breach
@@ -606,6 +586,35 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
}
}
+ sent := netEP.stats.icmp.packetsSent
+ icmpType, icmpCode, counter, pointer := func() (header.ICMPv4Type, header.ICMPv4Code, tcpip.MultiCounterStat, byte) {
+ switch reason := reason.(type) {
+ case *icmpReasonPortUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4PortUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonProtoUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4ProtoUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonNetworkUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4NetUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonHostUnreachable:
+ return header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonFragmentationNeeded:
+ return header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, sent.dstUnreachable, 0
+ case *icmpReasonTTLExceeded:
+ return header.ICMPv4TimeExceeded, header.ICMPv4TTLExceeded, sent.timeExceeded, 0
+ case *icmpReasonReassemblyTimeout:
+ return header.ICMPv4TimeExceeded, header.ICMPv4ReassemblyTimeout, sent.timeExceeded, 0
+ case *icmpReasonParamProblem:
+ return header.ICMPv4ParamProblem, header.ICMPv4UnusedCode, sent.paramProblem, reason.pointer
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ }()
+
+ if !p.allowICMPReply(icmpType, icmpCode) {
+ sent.rateLimited.Increment()
+ return nil
+ }
+
// Now work out how much of the triggering packet we should return.
// As per RFC 1812 Section 4.3.2.3
//
@@ -658,44 +667,9 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- var counter tcpip.MultiCounterStat
- switch reason := reason.(type) {
- case *icmpReasonPortUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonProtoUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonNetworkUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4NetUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonHostUnreachable:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4HostUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonFragmentationNeeded:
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4FragmentationNeeded)
- counter = sent.dstUnreachable
- case *icmpReasonTTLExceeded:
- icmpHdr.SetType(header.ICMPv4TimeExceeded)
- icmpHdr.SetCode(header.ICMPv4TTLExceeded)
- counter = sent.timeExceeded
- case *icmpReasonReassemblyTimeout:
- icmpHdr.SetType(header.ICMPv4TimeExceeded)
- icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
- counter = sent.timeExceeded
- case *icmpReasonParamProblem:
- icmpHdr.SetType(header.ICMPv4ParamProblem)
- icmpHdr.SetCode(header.ICMPv4UnusedCode)
- icmpHdr.SetPointer(reason.pointer)
- counter = sent.paramProblem
- default:
- panic(fmt.Sprintf("unsupported ICMP type %T", reason))
- }
+ icmpHdr.SetCode(icmpCode)
+ icmpHdr.SetType(icmpType)
+ icmpHdr.SetPointer(pointer)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum()))
if err := route.WritePacket(
diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go
index 4bd6f462e..c6576fcbc 100644
--- a/pkg/tcpip/network/ipv4/igmp_test.go
+++ b/pkg/tcpip/network/ipv4/igmp_test.go
@@ -120,9 +120,12 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma
// cycles.
func TestIGMPV1Present(t *testing.T) {
e, s, clock := createStack(t, true)
- addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength},
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
@@ -215,8 +218,15 @@ func TestSendQueuedIGMPReports(t *testing.T) {
// The initial set of IGMP reports that were queued should be sent once an
// address is assigned.
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: stackAddr,
+ PrefixLen: defaultPrefixLength,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if got := reportStat.Value(); got != 1 {
t.Errorf("got reportStat.Value() = %d, want = 1", got)
@@ -350,8 +360,12 @@ func TestIGMPPacketValidation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
e, s, _ := createStack(t, true)
for _, address := range test.stackAddresses {
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: address,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
stats := s.Stats()
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index e2472c851..d1d509702 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -167,6 +167,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
return nil
}
+func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ ep, ok := p.mu.eps[id]
+ return ep, ok
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -240,7 +247,7 @@ func (e *endpoint) Enable() tcpip.Error {
}
// Create an endpoint to receive broadcast packets on this interface.
- ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint})
if err != nil {
return err
}
@@ -419,7 +426,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -432,7 +439,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// We should do this for every packet, rather than only NATted packets, but
// removing this check short circuits broadcasts before they are sent out to
// other hosts.
- if pkt.NatDone {
+ if pkt.DNATDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
// Since we rewrote the packet but it is being routed back to us, we
@@ -459,7 +466,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
@@ -542,7 +549,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName)
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
for pkt := range outputDropped {
pkts.Remove(pkt)
@@ -569,7 +576,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
- postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName)
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
@@ -710,7 +717,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(ep.nic.ID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -737,7 +744,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(r.NICID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -746,7 +753,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
- newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader()))
+ newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength()))
+ newHdr := header.IPv4(newPkt.NetworkHeader().View())
// As per RFC 791 page 30, Time to Live,
//
@@ -755,12 +763,19 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// Even if no local information is available on the time actually
// spent, the field must be decremented by 1.
newHdr.SetTTL(ttl - 1)
+ // We perform a full checksum as we may have updated options above. The IP
+ // header is relatively small so this is not expected to be an expensive
+ // operation.
+ newHdr.SetChecksum(0)
+ newHdr.SetChecksum(^newHdr.CalculateChecksum())
+
+ forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID())
+ if !ok {
+ // The interface was removed after we obtained the route.
+ return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}}
+ }
- switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(newHdr).ToVectorisedView(),
- IsForwardedPacket: true,
- })); err.(type) {
+ switch err := forwardToEp.writePacket(r, newPkt, true /* headerIncluded */); err.(type) {
case nil:
return nil
case *tcpip.ErrMessageTooLong:
@@ -826,7 +841,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -925,7 +940,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok {
// iptables is telling us to drop the packet.
stats.ip.IPTablesInputDropped.Increment()
return
@@ -969,7 +984,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
}
proto := h.Protocol()
- resPkt, _, ready, err := e.protocol.fragmentation.Process(
+ resPkt, transProtoNum, ready, err := e.protocol.fragmentation.Process(
// As per RFC 791 section 2.3, the identification value is unique
// for a source-destination pair and protocol.
fragmentation.FragmentID{
@@ -1000,6 +1015,8 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer,
h.SetTotalLength(uint16(pkt.Data().Size() + len(h)))
h.SetFlagsFragmentOffset(0, 0)
+ e.protocol.parseTransport(pkt, tcpip.TransportProtocolNumber(transProtoNum))
+
// Now that the packet is reassembled, it can be sent to raw sockets.
e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
}
@@ -1075,11 +1092,11 @@ func (e *endpoint) Close() {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties)
if err == nil {
e.mu.igmp.sendQueuedReports()
}
@@ -1200,6 +1217,9 @@ type protocol struct {
// eps is keyed by NICID to allow protocol methods to retrieve an endpoint
// when handling a packet, by looking at which NIC handled the packet.
eps map[tcpip.NICID]*endpoint
+
+ // ICMP types for which the stack's global rate limiting must apply.
+ icmpRateLimitedTypes map[header.ICMPv4Type]struct{}
}
// defaultTTL is the current default TTL for the protocol. Only the
@@ -1226,11 +1246,6 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv4MinimumSize
}
-// DefaultPrefixLen returns the IPv4 default prefix length.
-func (p *protocol) DefaultPrefixLen() int {
- return header.IPv4AddressSize * 8
-}
-
// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv4(v)
@@ -1297,19 +1312,29 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool)
}
if hasTransportHdr {
- switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
- case stack.ParsedOK:
- case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
- // The transport layer will handle unknown protocols and transport layer
- // parsing errors.
- default:
- panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
- }
+ p.parseTransport(pkt, transProtoNum)
}
return h, true
}
+func (p *protocol) parseTransport(pkt *stack.PacketBuffer, transProtoNum tcpip.TransportProtocolNumber) {
+ if transProtoNum == header.ICMPv4ProtocolNumber {
+ // The transport layer will handle transport layer parsing errors.
+ _ = parse.ICMPv4(pkt)
+ return
+ }
+
+ switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
+ case stack.ParsedOK:
+ case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
+ // The transport layer will handle unknown protocols and transport layer
+ // parsing errors.
+ default:
+ panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
+ }
+}
+
// Parse implements stack.NetworkProtocol.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
if ok := parse.IPv4(pkt); !ok {
@@ -1320,6 +1345,23 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
+// allowICMPReply reports whether an ICMP reply with provided type and code may
+// be sent following the rate mask options and global ICMP rate limiter.
+func (p *protocol) allowICMPReply(icmpType header.ICMPv4Type, code header.ICMPv4Code) bool {
+ // Mimic Linux and never rate limit for PMTU discovery.
+ // https://github.com/torvalds/linux/blob/9e9fb7655ed585da8f468e29221f0ba194a5f613/net/ipv4/icmp.c#L288
+ if icmpType == header.ICMPv4DstUnreachable && code == header.ICMPv4FragmentationNeeded {
+ return true
+ }
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok {
+ return p.stack.AllowICMPMessage()
+ }
+ return true
+}
+
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
@@ -1399,6 +1441,14 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
}
p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[tcpip.NICID]*endpoint)
+ // Set ICMP rate limiting to Linux defaults.
+ // See https://man7.org/linux/man-pages/man7/icmp.7.html.
+ p.mu.icmpRateLimitedTypes = map[header.ICMPv4Type]struct{}{
+ header.ICMPv4DstUnreachable: struct{}{},
+ header.ICMPv4SrcQuench: struct{}{},
+ header.ICMPv4TimeExceeded: struct{}{},
+ header.ICMPv4ParamProblem: struct{}{},
+ }
return p
}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 73407be67..ef91245d7 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -101,8 +101,12 @@ func TestExcludeBroadcast(t *testing.T) {
defer ep.Close()
// Add a valid primary endpoint address, now we can connect.
- if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address("\x0a\x00\x00\x02").WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
if err := ep.Connect(randomAddr); err != nil {
t.Errorf("Connect failed: %v", err)
@@ -356,8 +360,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv4ProtoAddr, err)
}
expectedEmittedPacketCount := 1
@@ -369,8 +373,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv4ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1184,8 +1188,8 @@ func TestIPv4Sanity(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
}
// Default routes for IPv4 so ICMP can find a route to the remote
@@ -1745,8 +1749,8 @@ func TestInvalidFragments(t *testing.T) {
const (
nicID = 1
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x02")
tos = 0
ident = 1
ttl = 48
@@ -2012,8 +2016,12 @@ func TestInvalidFragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
for _, f := range test.fragments {
@@ -2061,8 +2069,8 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
const (
nicID = 1
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- addr1 = "\x0a\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x02")
tos = 0
ident = 1
ttl = 48
@@ -2237,8 +2245,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
@@ -2308,9 +2320,9 @@ func TestReceiveFragments(t *testing.T) {
const (
nicID = 1
- addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
- addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
- addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3
+ addr1 = tcpip.Address("\x0c\xa8\x00\x01") // 192.168.0.1
+ addr2 = tcpip.Address("\x0c\xa8\x00\x02") // 192.168.0.2
+ addr3 = tcpip.Address("\x0c\xa8\x00\x03") // 192.168.0.3
)
// Build and return a UDP header containing payload.
@@ -2703,8 +2715,12 @@ func TestReceiveFragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
wq := waiter.Queue{}
@@ -2985,11 +3001,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
t.Fatalf("CreateNIC(1, _) failed: %s", err)
}
const (
- src = "\x10\x00\x00\x01"
- dst = "\x10\x00\x00\x02"
+ src = tcpip.Address("\x10\x00\x00\x01")
+ dst = tcpip.Address("\x10\x00\x00\x02")
)
- if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
mask := tcpip.AddressMask(header.IPv4Broadcast)
@@ -3161,8 +3181,8 @@ func TestPacketQueuing(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -3285,8 +3305,12 @@ func TestCloseLocking(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
@@ -3349,3 +3373,139 @@ func TestCloseLocking(t *testing.T) {
}
}()
}
+
+func TestIcmpRateLimit(t *testing.T) {
+ var (
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 24,
+ },
+ }
+ )
+ const icmpBurst = 5
+ e := channel.New(1, defaultMTU, tcpip.LinkAddress(""))
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: faketime.NewManualClock(),
+ })
+ s.SetICMPBurst(icmpBurst)
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+ tests := []struct {
+ name string
+ createPacket func() buffer.View
+ check func(*testing.T, *channel.Endpoint, int)
+ }{
+ {
+ name: "echo",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpH.SetIdent(1)
+ icmpH.SetSequence(1)
+ icmpH.SetType(header.ICMPv4Echo)
+ icmpH.SetCode(header.ICMPv4UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(^header.Checksum(icmpH, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLength),
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 1,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected echo response, no packet read in endpoint in round %d", round)
+ }
+ if got, want := p.Proto, header.IPv4ProtocolNumber; got != want {
+ t.Errorf("got p.Proto = %d, want = %d", got, want)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply),
+ ))
+ },
+ },
+ {
+ name: "dst unreachable",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv4MinimumSize + header.UDPMinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpH.Encode(&header.UDPFields{
+ SrcPort: 100,
+ DstPort: 101,
+ Length: header.UDPMinimumSize,
+ })
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLength),
+ Protocol: uint8(header.UDPProtocolNumber),
+ TTL: 1,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if round >= icmpBurst {
+ if ok {
+ t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round)
+ }
+ return
+ }
+ if !ok {
+ t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ ))
+ },
+ },
+ }
+ for _, testCase := range tests {
+ t.Run(testCase.name, func(t *testing.T) {
+ for round := 0; round < icmpBurst+1; round++ {
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: testCase.createPacket().ToVectorisedView(),
+ }))
+ testCase.check(t, e, round)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index f99cbf8f3..f814926a3 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -51,6 +51,7 @@ go_test(
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
+ "@org_golang_x_time//rate:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 94caaae6c..adfc8d8da 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -187,7 +187,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// Skip the IP header, then handle the fragmentation header if there
// is one.
- pkt.Data().DeleteFront(header.IPv6MinimumSize)
+ if _, ok := pkt.Data().Consume(header.IPv6MinimumSize); !ok {
+ panic("could not consume IPv6MinimumSize bytes")
+ }
if p == header.IPv6FragmentHeader {
f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize)
if !ok {
@@ -203,7 +205,9 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe
// Skip fragmentation header and find out the actual protocol
// number.
- pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize)
+ if _, ok := pkt.Data().Consume(header.IPv6FragmentHeaderSize); !ok {
+ panic("could not consume IPv6FragmentHeaderSize bytes")
+ }
}
e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt)
@@ -270,7 +274,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP
if routerAlert == nil || routerAlert.Value != header.IPv6RouterAlertMLD {
return false
}
- if pkt.Data().Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize {
+ if pkt.TransportHeader().View().Size() < header.ICMPv6HeaderSize+header.MLDMinimumSize {
return false
}
if iph.HopLimit() != header.MLDHopLimit {
@@ -285,20 +289,17 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, routerAlert *header.IPv6RouterAlertOption) {
sent := e.stats.icmp.packetsSent
received := e.stats.icmp.packetsReceived
- // ICMP packets don't have their TransportHeader fields set. See
- // icmp/protocol.go:protocol.Parse for a full explanation.
- v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize)
- if !ok {
+ h := header.ICMPv6(pkt.TransportHeader().View())
+ if len(h) < header.ICMPv6MinimumSize {
received.invalid.Increment()
return
}
- h := header.ICMPv6(v)
iph := header.IPv6(pkt.NetworkHeader().View())
srcAddr := iph.SourceAddress()
dstAddr := iph.DestinationAddress()
// Validate ICMPv6 checksum before processing the packet.
- payload := pkt.Data().AsRange().SubRange(len(h))
+ payload := pkt.Data().AsRange()
if got, want := h.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: h,
Src: srcAddr,
@@ -325,28 +326,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
received.packetTooBig.Increment()
- hdr, ok := pkt.Data().PullUp(header.ICMPv6PacketTooBigMinimumSize)
- if !ok {
- received.invalid.Increment()
- return
- }
- networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize)
+ networkMTU, err := calculateNetworkMTU(h.MTU(), header.IPv6MinimumSize)
if err != nil {
networkMTU = 0
}
- pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize)
e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt)
case header.ICMPv6DstUnreachable:
received.dstUnreachable.Increment()
- hdr, ok := pkt.Data().PullUp(header.ICMPv6DstUnreachableMinimumSize)
- if !ok {
- received.invalid.Increment()
- return
- }
- code := header.ICMPv6(hdr).Code()
- pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize)
- switch code {
+ switch h.Code() {
case header.ICMPv6NetworkUnreachable:
e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt)
case header.ICMPv6PortUnreachable:
@@ -354,16 +342,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
}
case header.ICMPv6NeighborSolicit:
received.neighborSolicit.Increment()
- if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborSolicitMinimumSize {
+ if !isNDPValid() || len(h) < header.ICMPv6NeighborSolicitMinimumSize {
received.invalid.Increment()
return
}
- // The remainder of payload must be only the neighbor solicitation, so
- // payload.AsView() always returns the solicitation. Per RFC 6980 section 5,
- // NDP messages cannot be fragmented. Also note that in the common case NDP
- // datagrams are very small and AsView() will not incur allocations.
- ns := header.NDPNeighborSolicit(payload.AsView())
+ ns := header.NDPNeighborSolicit(h.MessageBody())
targetAddr := ns.TargetAddress()
// As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast
@@ -576,16 +560,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
case header.ICMPv6NeighborAdvert:
received.neighborAdvert.Increment()
- if !isNDPValid() || pkt.Data().Size() < header.ICMPv6NeighborAdvertMinimumSize {
+ if !isNDPValid() || len(h) < header.ICMPv6NeighborAdvertMinimumSize {
received.invalid.Increment()
return
}
- // The remainder of payload must be only the neighbor advertisement, so
- // payload.AsView() always returns the advertisement. Per RFC 6980 section
- // 5, NDP messages cannot be fragmented. Also note that in the common case
- // NDP datagrams are very small and AsView() will not incur allocations.
- na := header.NDPNeighborAdvert(payload.AsView())
+ na := header.NDPNeighborAdvert(h.MessageBody())
it, err := na.Options().Iter(false /* check */)
if err != nil {
@@ -672,12 +652,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
case header.ICMPv6EchoRequest:
received.echoRequest.Increment()
- icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize)
- if !ok {
- received.invalid.Increment()
- return
- }
-
// As per RFC 4291 section 2.7, multicast addresses must not be used as
// source addresses in IPv6 packets.
localAddr := dstAddr
@@ -692,13 +666,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
}
defer r.Release()
+ if !e.protocol.allowICMPReply(header.ICMPv6EchoReply) {
+ sent.rateLimited.Increment()
+ return
+ }
+
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
Data: pkt.Data().ExtractVV(),
})
icmp := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize))
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
- copy(icmp, icmpHdr)
+ copy(icmp, h)
icmp.SetType(header.ICMPv6EchoReply)
dataRange := replyPkt.Data().AsRange()
icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
@@ -720,7 +699,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
case header.ICMPv6EchoReply:
received.echoReply.Increment()
- if pkt.Data().Size() < header.ICMPv6EchoMinimumSize {
+ if len(h) < header.ICMPv6EchoMinimumSize {
received.invalid.Increment()
return
}
@@ -740,7 +719,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// Is the NDP payload of sufficient size to hold a Router Solictation?
- if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize {
+ if !isNDPValid() || len(h)-header.ICMPv6HeaderSize < header.NDPRSMinimumSize {
received.invalid.Increment()
return
}
@@ -750,9 +729,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- // Note that in the common case NDP datagrams are very small and AsView()
- // will not incur allocations.
- rs := header.NDPRouterSolicit(payload.AsView())
+ rs := header.NDPRouterSolicit(h.MessageBody())
it, err := rs.Options().Iter(false /* check */)
if err != nil {
// Options are not valid as per the wire format, silently drop the packet.
@@ -796,7 +773,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// Is the NDP payload of sufficient size to hold a Router Advertisement?
- if !isNDPValid() || pkt.Data().Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize {
+ if !isNDPValid() || len(h)-header.ICMPv6HeaderSize < header.NDPRAMinimumSize {
received.invalid.Increment()
return
}
@@ -810,9 +787,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
return
}
- // Note that in the common case NDP datagrams are very small and AsView()
- // will not incur allocations.
- ra := header.NDPRouterAdvert(payload.AsView())
+ ra := header.NDPRouterAdvert(h.MessageBody())
it, err := ra.Options().Iter(false /* check */)
if err != nil {
// Options are not valid as per the wire format, silently drop the packet.
@@ -890,11 +865,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
switch icmpType {
case header.ICMPv6MulticastListenerQuery:
e.mu.Lock()
- e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.AsView()))
+ e.mu.mld.handleMulticastListenerQuery(header.MLD(h.MessageBody()))
e.mu.Unlock()
case header.ICMPv6MulticastListenerReport:
e.mu.Lock()
- e.mu.mld.handleMulticastListenerReport(header.MLD(payload.AsView()))
+ e.mu.mld.handleMulticastListenerReport(header.MLD(h.MessageBody()))
e.mu.Unlock()
case header.ICMPv6MulticastListenerDone:
default:
@@ -1174,28 +1149,37 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
return &tcpip.ErrNotConnected{}
}
- sent := netEP.stats.icmp.packetsSent
-
- if !p.stack.AllowICMPMessage() {
- sent.rateLimited.Increment()
- return nil
- }
-
if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber {
- // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored.
- // Unfortunately at this time ICMP Packets do not have a transport
- // header separated out. It is in the Data part so we need to
- // separate it out now. We will just pretend it is a minimal length
- // ICMP packet as we don't really care if any later bits of a
- // larger ICMP packet are in the header view or in the Data view.
- transport, ok := pkt.TransportHeader().Consume(header.ICMPv6MinimumSize)
- if !ok {
+ if typ := header.ICMPv6(pkt.TransportHeader().View()).Type(); typ.IsErrorType() || typ == header.ICMPv6RedirectMsg {
return nil
}
- typ := header.ICMPv6(transport).Type()
- if typ.IsErrorType() || typ == header.ICMPv6RedirectMsg {
- return nil
+ }
+
+ sent := netEP.stats.icmp.packetsSent
+ icmpType, icmpCode, counter, typeSpecific := func() (header.ICMPv6Type, header.ICMPv6Code, tcpip.MultiCounterStat, uint32) {
+ switch reason := reason.(type) {
+ case *icmpReasonParameterProblem:
+ return header.ICMPv6ParamProblem, reason.code, sent.paramProblem, reason.pointer
+ case *icmpReasonPortUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonNetUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6NetworkUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonHostUnreachable:
+ return header.ICMPv6DstUnreachable, header.ICMPv6AddressUnreachable, sent.dstUnreachable, 0
+ case *icmpReasonPacketTooBig:
+ return header.ICMPv6PacketTooBig, header.ICMPv6UnusedCode, sent.packetTooBig, 0
+ case *icmpReasonHopLimitExceeded:
+ return header.ICMPv6TimeExceeded, header.ICMPv6HopLimitExceeded, sent.timeExceeded, 0
+ case *icmpReasonReassemblyTimeout:
+ return header.ICMPv6TimeExceeded, header.ICMPv6ReassemblyTimeout, sent.timeExceeded, 0
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
+ }()
+
+ if !p.allowICMPReply(icmpType) {
+ sent.rateLimited.Increment()
+ return nil
}
network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
@@ -1232,40 +1216,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip
newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
- var counter tcpip.MultiCounterStat
- switch reason := reason.(type) {
- case *icmpReasonParameterProblem:
- icmpHdr.SetType(header.ICMPv6ParamProblem)
- icmpHdr.SetCode(reason.code)
- icmpHdr.SetTypeSpecific(reason.pointer)
- counter = sent.paramProblem
- case *icmpReasonPortUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6PortUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonNetUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6NetworkUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonHostUnreachable:
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetCode(header.ICMPv6AddressUnreachable)
- counter = sent.dstUnreachable
- case *icmpReasonPacketTooBig:
- icmpHdr.SetType(header.ICMPv6PacketTooBig)
- icmpHdr.SetCode(header.ICMPv6UnusedCode)
- counter = sent.packetTooBig
- case *icmpReasonHopLimitExceeded:
- icmpHdr.SetType(header.ICMPv6TimeExceeded)
- icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
- counter = sent.timeExceeded
- case *icmpReasonReassemblyTimeout:
- icmpHdr.SetType(header.ICMPv6TimeExceeded)
- icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
- counter = sent.timeExceeded
- default:
- panic(fmt.Sprintf("unsupported ICMP type %T", reason))
- }
+ icmpHdr.SetType(icmpType)
+ icmpHdr.SetCode(icmpCode)
+ icmpHdr.SetTypeSpecific(typeSpecific)
+
dataRange := newPkt.Data().AsRange()
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 7c2a3e56b..03d9f425c 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -225,8 +226,8 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
}
addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
@@ -407,8 +408,12 @@ func newTestContext(t *testing.T) *testContext {
if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
- if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress lladdr0: %v", err)
+ llProtocolAddr0 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := c.s0.AddProtocolAddress(nicID, llProtocolAddr0, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr0, err)
}
c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
@@ -416,8 +421,12 @@ func newTestContext(t *testing.T) *testContext {
if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil {
- t.Fatalf("AddAddress lladdr1: %v", err)
+ llProtocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr1.WithPrefix(),
+ }
+ if err := c.s1.AddProtocolAddress(nicID, llProtocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, llProtocolAddr1, err)
}
subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -690,8 +699,12 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -883,8 +896,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -1065,8 +1082,12 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
@@ -1240,8 +1261,12 @@ func TestLinkAddressRequest(t *testing.T) {
}
if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: test.nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -1411,12 +1436,14 @@ func TestPacketQueing(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
Clock: clock,
})
+ // Make sure ICMP rate limiting doesn't get in our way.
+ s.SetICMPLimit(rate.Inf)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err)
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1669,8 +1696,12 @@ func TestCallsToNeighborCache(t *testing.T) {
if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_, _) = %s", err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
{
@@ -1704,8 +1735,8 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint")
}
addr := lladdr0.WithPrefix()
- if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil {
- t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err)
+ if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, {}): %s", addr, err)
} else {
ep.DecRef()
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index d4bd61748..7d3e1fd53 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -748,7 +748,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesOutputDropped.Increment()
return nil
@@ -761,7 +761,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// We should do this for every packet, rather than only NATted packets, but
// removing this check short circuits broadcasts before they are sent out to
// other hosts.
- if pkt.NatDone {
+ if pkt.DNATDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
// Since we rewrote the packet but it is being routed back to us, we
@@ -788,7 +788,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, e, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
@@ -871,7 +871,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName)
+ outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
for pkt := range outputDropped {
pkts.Remove(pkt)
@@ -897,7 +897,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
- postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName)
+ postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, e, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
@@ -984,7 +984,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(ep.nic.ID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -1015,7 +1015,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
inNicName := stk.FindNICNameFromID(e.nic.ID())
outNicName := stk.FindNICNameFromID(r.NICID())
- if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok {
+ if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesForwardDropped.Increment()
return nil
@@ -1024,7 +1024,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// We need to do a deep copy of the IP packet because
// WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do
// not own it.
- newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader()))
+ newPkt := pkt.DeepCopyForForwarding(int(r.MaxHeaderLength()))
+ newHdr := header.IPv6(newPkt.NetworkHeader().View())
// As per RFC 8200 section 3,
//
@@ -1032,11 +1033,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
// each node that forwards the packet.
newHdr.SetHopLimit(hopLimit - 1)
- switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: buffer.View(newHdr).ToVectorisedView(),
- IsForwardedPacket: true,
- })); err.(type) {
+ forwardToEp, ok := e.protocol.getEndpointForNIC(r.NICID())
+ if !ok {
+ // The interface was removed after we obtained the route.
+ return &ip.ErrOther{Err: &tcpip.ErrUnknownDevice{}}
+ }
+
+ switch err := forwardToEp.writePacket(r, newPkt, newPkt.TransportProtocolNumber, true /* headerIncluded */); err.(type) {
case nil:
return nil
case *tcpip.ErrMessageTooLong:
@@ -1097,7 +1100,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Loopback traffic skips the prerouting chain.
inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesPreroutingDropped.Increment()
return
@@ -1180,7 +1183,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer,
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesInputDropped.Increment()
return
@@ -1534,27 +1537,36 @@ func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffe
// If the last header in the payload isn't a known IPv6 extension header,
// handle it as if it is transport layer data.
- // Calculate the number of octets parsed from data. We want to remove all
- // the data except the unparsed portion located at the end, which its size
- // is extHdr.Buf.Size().
+ // Calculate the number of octets parsed from data. We want to consume all
+ // the data except the unparsed portion located at the end, whose size is
+ // extHdr.Buf.Size().
trim := pkt.Data().Size() - extHdr.Buf.Size()
// For unfragmented packets, extHdr still contains the transport header.
- // Get rid of it.
+ // Consume that too.
//
// For reassembled fragments, pkt.TransportHeader is unset, so this is a
// no-op and pkt.Data begins with the transport header.
trim += pkt.TransportHeader().View().Size()
- pkt.Data().DeleteFront(trim)
+ if _, ok := pkt.Data().Consume(trim); !ok {
+ stats.MalformedPacketsReceived.Increment()
+ return fmt.Errorf("could not consume %d bytes", trim)
+ }
+
+ proto := tcpip.TransportProtocolNumber(extHdr.Identifier)
+ // If the packet was reassembled from a fragment, it will not have a
+ // transport header set yet.
+ if pkt.TransportHeader().View().IsEmpty() {
+ e.protocol.parseTransport(pkt, proto)
+ }
stats.PacketsDelivered.Increment()
- if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
- pkt.TransportProtocolNumber = p
+ if proto == header.ICMPv6ProtocolNumber {
e.handleICMP(pkt, hasFragmentHeader, routerAlert)
} else {
stats.PacketsDelivered.Increment()
- switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res {
+ switch res := e.dispatcher.DeliverTransportPacket(proto, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
// As per RFC 4443 section 3.1:
@@ -1628,12 +1640,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
e.mu.Lock()
defer e.mu.Unlock()
- return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
+ return e.addAndAcquirePermanentAddressLocked(addr, properties)
}
// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but
@@ -1643,8 +1655,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
// solicited-node multicast group and start duplicate address detection.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
- addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, properties stack.AddressProperties) (stack.AddressEndpoint, tcpip.Error) {
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, properties)
if err != nil {
return nil, err
}
@@ -1987,6 +1999,9 @@ type protocol struct {
// eps is keyed by NICID to allow protocol methods to retrieve an endpoint
// when handling a packet, by looking at which NIC handled the packet.
eps map[tcpip.NICID]*endpoint
+
+ // ICMP types for which the stack's global rate limiting must apply.
+ icmpRateLimitedTypes map[header.ICMPv6Type]struct{}
}
ids []uint32
@@ -1998,7 +2013,8 @@ type protocol struct {
// Must be accessed using atomic operations.
defaultTTL uint32
- fragmentation *fragmentation.Fragmentation
+ fragmentation *fragmentation.Fragmentation
+ icmpRateLimiter *stack.ICMPRateLimiter
}
// Number returns the ipv6 protocol number.
@@ -2011,11 +2027,6 @@ func (p *protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
-// DefaultPrefixLen returns the IPv6 default prefix length.
-func (p *protocol) DefaultPrefixLen() int {
- return header.IPv6AddressSize * 8
-}
-
// ParseAddresses implements stack.NetworkProtocol.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
@@ -2087,6 +2098,13 @@ func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
return nil
}
+func (p *protocol) getEndpointForNIC(id tcpip.NICID) (*endpoint, bool) {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ ep, ok := p.mu.eps[id]
+ return ep, ok
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -2149,19 +2167,29 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool)
}
if hasTransportHdr {
- switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
- case stack.ParsedOK:
- case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
- // The transport layer will handle unknown protocols and transport layer
- // parsing errors.
- default:
- panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
- }
+ p.parseTransport(pkt, transProtoNum)
}
return h, true
}
+func (p *protocol) parseTransport(pkt *stack.PacketBuffer, transProtoNum tcpip.TransportProtocolNumber) {
+ if transProtoNum == header.ICMPv6ProtocolNumber {
+ // The transport layer will handle transport layer parsing errors.
+ _ = parse.ICMPv6(pkt)
+ return
+ }
+
+ switch err := p.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
+ case stack.ParsedOK:
+ case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
+ // The transport layer will handle unknown protocols and transport layer
+ // parsing errors.
+ default:
+ panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
+ }
+}
+
// Parse implements stack.NetworkProtocol.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt)
@@ -2172,6 +2200,18 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return proto, !fragMore && fragOffset == 0, true
}
+// allowICMPReply reports whether an ICMP reply with provided type may
+// be sent following the rate mask options and global ICMP rate limiter.
+func (p *protocol) allowICMPReply(icmpType header.ICMPv6Type) bool {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok {
+ return p.stack.AllowICMPMessage()
+ }
+ return true
+}
+
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
@@ -2268,6 +2308,21 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
p.mu.eps = make(map[tcpip.NICID]*endpoint)
p.SetDefaultTTL(DefaultTTL)
+ // Set default ICMP rate limiting to Linux defaults.
+ //
+ // Default: 0-1,3-127 (rate limit ICMPv6 errors except Packet Too Big)
+ // See https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt.
+ defaultIcmpTypes := make(map[header.ICMPv6Type]struct{})
+ for i := header.ICMPv6Type(0); i < header.ICMPv6EchoRequest; i++ {
+ switch i {
+ case header.ICMPv6PacketTooBig:
+ // Do not rate limit packet too big by default.
+ default:
+ defaultIcmpTypes[i] = struct{}{}
+ }
+ }
+ p.mu.icmpRateLimitedTypes = defaultIcmpTypes
+
return p
}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index d2a23fd4f..e5286081e 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -41,12 +41,12 @@ import (
)
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
// The least significant 3 bytes are the same as addr2 so both addr2 and
// addr3 will have the same solicited-node address.
- addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
- addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03"
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02")
+ addr4 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03")
// Tests use the extension header identifier values as uint8 instead of
// header.IPv6ExtensionHeaderIdentifier.
@@ -298,16 +298,24 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
// addr2/addr3 yet as we haven't added those addresses.
test.rxf(t, s, e, addr1, snmc, 0)
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr2, err)
}
// Should receive a packet destined to the solicited node address of
// addr2/addr3 now that we have added added addr2.
test.rxf(t, s, e, addr1, snmc, 1)
- if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr3.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr3, err)
}
// Should still receive a packet destined to the solicited node address of
@@ -374,8 +382,12 @@ func TestAddIpv6Address(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, test.addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, nil) = %s", nicID, ProtocolNumber, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: test.addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if addr, err := s.GetMainNICAddress(nicID, ProtocolNumber); err != nil {
@@ -898,8 +910,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Add a default route so that a return packet knows where to go.
@@ -1992,8 +2008,12 @@ func TestReceiveIPv6Fragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
wq := waiter.Queue{}
@@ -2060,8 +2080,8 @@ func TestReceiveIPv6Fragments(t *testing.T) {
func TestInvalidIPv6Fragments(t *testing.T) {
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
nicID = 1
hoplimit = 255
@@ -2150,8 +2170,12 @@ func TestInvalidIPv6Fragments(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
@@ -2216,8 +2240,8 @@ func TestInvalidIPv6Fragments(t *testing.T) {
func TestFragmentReassemblyTimeout(t *testing.T) {
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
nicID = 1
hoplimit = 255
@@ -2402,8 +2426,12 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: addr2.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
@@ -2645,11 +2673,15 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route {
t.Fatalf("CreateNIC(1, _) failed: %s", err)
}
const (
- src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ src = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ dst = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
)
- if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: src.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
@@ -3297,8 +3329,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err)
}
incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr}
- if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingIPv6ProtoAddr, err)
}
outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "")
@@ -3306,8 +3338,8 @@ func TestForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err)
}
outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingIPv6ProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -3341,7 +3373,8 @@ func TestForwarding(t *testing.T) {
ipHeaderLength := header.IPv6MinimumSize
icmpHeaderLength := header.ICMPv6MinimumSize
- totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen
+ payloadLength := icmpHeaderLength + test.payloadLength + extHdrLen
+ totalLength := ipHeaderLength + payloadLength
hdr := buffer.NewPrependable(totalLength)
hdr.Prepend(test.payloadLength)
icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength))
@@ -3359,7 +3392,7 @@ func TestForwarding(t *testing.T) {
copy(hdr.Prepend(extHdrLen), extHdrBytes)
ip := header.IPv6(hdr.Prepend(ipHeaderLength))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength),
+ PayloadLength: uint16(payloadLength),
TransportProtocol: transportProtocol,
HopLimit: test.TTL,
SrcAddr: test.sourceAddr,
@@ -3489,3 +3522,149 @@ func TestMultiCounterStatsInitialization(t *testing.T) {
t.Error(err)
}
}
+
+func TestIcmpRateLimit(t *testing.T) {
+ var (
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+ )
+ const icmpBurst = 5
+ e := channel.New(1, defaultMTU, tcpip.LinkAddress(""))
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: faketime.NewManualClock(),
+ })
+ s.SetICMPBurst(icmpBurst)
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+ tests := []struct {
+ name string
+ createPacket func() buffer.View
+ check func(*testing.T, *channel.Endpoint, int)
+ }{
+ {
+ name: "echo",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ icmpH := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ icmpH.SetIdent(1)
+ icmpH.SetSequence(1)
+ icmpH.SetType(header.ICMPv6EchoRequest)
+ icmpH.SetCode(header.ICMPv6UnusedCode)
+ icmpH.SetChecksum(0)
+ icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: icmpH,
+ Src: host2IPv6Addr.AddressWithPrefix.Address,
+ Dst: host1IPv6Addr.AddressWithPrefix.Address,
+ }))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 1,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected echo response, no packet read in endpoint in round %d", round)
+ }
+ if got, want := p.Proto, header.IPv6ProtocolNumber; got != want {
+ t.Errorf("got p.Proto = %d, want = %d", got, want)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply),
+ ))
+ },
+ },
+ {
+ name: "dst unreachable",
+ createPacket: func() buffer.View {
+ totalLength := header.IPv6MinimumSize + header.UDPMinimumSize
+ hdr := buffer.NewPrependable(totalLength)
+ udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udpH.Encode(&header.UDPFields{
+ SrcPort: 100,
+ DstPort: 101,
+ Length: header.UDPMinimumSize,
+ })
+
+ // Calculate the UDP checksum and set it.
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(nil, sum)
+ udpH.SetChecksum(^udpH.CalculateChecksum(sum))
+
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.UDPProtocolNumber,
+ HopLimit: 1,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ return hdr.View()
+ },
+ check: func(t *testing.T, e *channel.Endpoint, round int) {
+ p, ok := e.Read()
+ if round >= icmpBurst {
+ if ok {
+ t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round)
+ }
+ return
+ }
+ if !ok {
+ t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ ))
+ },
+ },
+ }
+ for _, testCase := range tests {
+ t.Run(testCase.name, func(t *testing.T) {
+ for round := 0; round < icmpBurst+1; round++ {
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: testCase.createPacket().ToVectorisedView(),
+ }))
+ testCase.check(t, e, round)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index bc9cf6999..3e5c438d3 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -75,8 +75,12 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
// The stack will join an address's solicited node multicast address when
// an address is added. An MLD report message should be sent for the
// solicited-node group.
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
@@ -216,8 +220,13 @@ func TestSendQueuedMLDReports(t *testing.T) {
// Note, we will still expect to send a report for the global address's
// solicited node address from the unspecified address as per RFC 3590
// section 4.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ globalProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: globalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, globalProtocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, globalProtocolAddr, properties, err)
}
reportCounter++
if got := reportStat.Value(); got != reportCounter {
@@ -252,8 +261,12 @@ func TestSendQueuedMLDReports(t *testing.T) {
// Adding a link-local address should send a report for its solicited node
// address and globalMulticastAddr.
- if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err)
+ linkLocalProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, linkLocalProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, linkLocalProtocolAddr, err)
}
if dadResolutionTime != 0 {
reportCounter++
@@ -567,8 +580,12 @@ func TestMLDSkipProtocol(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 8837d66d8..938427420 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -1130,7 +1130,11 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config
return nil
}
- addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
+ addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.AddressProperties{
+ PEB: stack.FirstPrimaryEndpoint,
+ ConfigType: configType,
+ Deprecated: deprecated,
+ })
if err != nil {
panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index f0186c64e..8297a7e10 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -144,8 +144,12 @@ func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
@@ -406,8 +410,12 @@ func TestNeighborSolicitationResponse(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: nicAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -602,8 +610,12 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
@@ -831,8 +843,12 @@ func TestNDPValidation(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
@@ -962,8 +978,12 @@ func TestNeighborAdvertisementValidation(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
ndpNASize := header.ICMPv6NeighborAdvertMinimumSize
@@ -1283,8 +1303,12 @@ func TestCheckDuplicateAddress(t *testing.T) {
checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonces[dadPacketsSent])}),
))
}
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: lladdr0.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
checkDADMsg()
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
index 1b96b1fb8..26640b7ee 100644
--- a/pkg/tcpip/network/multicast_group_test.go
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -151,15 +151,22 @@ func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.Link
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addr := tcpip.AddressWithPrefix{
- Address: stackIPv4Addr,
- PrefixLen: defaultIPv4PrefixLength,
+ addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: stackIPv4Addr,
+ PrefixLen: defaultIPv4PrefixLength,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: linkLocalIPv6Addr1.WithPrefix(),
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, linkLocalIPv6Addr1, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, clock
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 009cab643..05b879543 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -146,8 +146,12 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil {
- log.Fatal(err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Add default route.
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index c10b19aa0..a72afadda 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -124,13 +124,13 @@ func main() {
log.Fatalf("Bad IP address: %v", addrName)
}
- var addr tcpip.Address
+ var addrWithPrefix tcpip.AddressWithPrefix
var proto tcpip.NetworkProtocolNumber
if parsedAddr.To4() != nil {
- addr = tcpip.Address(parsedAddr.To4())
+ addrWithPrefix = tcpip.Address(parsedAddr.To4()).WithPrefix()
proto = ipv4.ProtocolNumber
} else if parsedAddr.To16() != nil {
- addr = tcpip.Address(parsedAddr.To16())
+ addrWithPrefix = tcpip.Address(parsedAddr.To16()).WithPrefix()
proto = ipv6.ProtocolNumber
} else {
log.Fatalf("Unknown IP type: %v", addrName)
@@ -176,11 +176,15 @@ func main() {
log.Fatal(err)
}
- if err := s.AddAddress(1, proto, addr); err != nil {
- log.Fatal(err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ log.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
- subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr))))
+ subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addrWithPrefix.Address))), tcpip.AddressMask(strings.Repeat("\x00", len(addrWithPrefix.Address))))
if err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 34ac62444..b0b2d0afd 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -170,10 +170,14 @@ type SocketOptions struct {
// message is passed with incoming packets.
receiveTClassEnabled uint32
- // receivePacketInfoEnabled is used to specify if more inforamtion is
- // provided with incoming packets such as interface index and address.
+ // receivePacketInfoEnabled is used to specify if more information is
+ // provided with incoming IPv4 packets.
receivePacketInfoEnabled uint32
+	// receiveIPv6PacketInfoEnabled is used to specify if more information is
+	// provided with incoming IPv6 packets.
+ receiveIPv6PacketInfoEnabled uint32
+
// hdrIncludeEnabled is used to indicate for a raw endpoint that all packets
// being written have an IP header and the endpoint should not attach an IP
// header.
@@ -360,6 +364,16 @@ func (so *SocketOptions) SetReceivePacketInfo(v bool) {
storeAtomicBool(&so.receivePacketInfoEnabled, v)
}
+// GetIPv6ReceivePacketInfo gets value for IPV6_RECVPKTINFO option.
+func (so *SocketOptions) GetIPv6ReceivePacketInfo() bool {
+ return atomic.LoadUint32(&so.receiveIPv6PacketInfoEnabled) != 0
+}
+
+// SetIPv6ReceivePacketInfo sets value for IPV6_RECVPKTINFO option.
+func (so *SocketOptions) SetIPv6ReceivePacketInfo(v bool) {
+ storeAtomicBool(&so.receiveIPv6PacketInfoEnabled, v)
+}
+
// GetHeaderIncluded gets value for IP_HDRINCL option.
func (so *SocketOptions) GetHeaderIncluded() bool {
return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 6c42ab29b..ead36880f 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -48,7 +48,6 @@ go_library(
"hook_string.go",
"icmp_rate_limit.go",
"iptables.go",
- "iptables_state.go",
"iptables_targets.go",
"iptables_types.go",
"neighbor_cache.go",
@@ -133,6 +132,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
+ "conntrack_test.go",
"forwarding_test.go",
"neighbor_cache_test.go",
"neighbor_entry_test.go",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index ae0bb4ace..7e4b5bf74 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -117,10 +117,10 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS
}
// AddAndAcquirePermanentAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) {
+func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
- ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, properties, true /* permanent */)
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
@@ -149,7 +149,7 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr
func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
- ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: peb}, false /* permanent */)
// From https://golang.org/doc/faq#nil_error:
//
// Under the covers, interfaces are implemented as two elements, a type T and
@@ -180,7 +180,7 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr
// returned, regardless the kind of address that is being added.
//
// Precondition: a.mu must be write locked.
-func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) {
+func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, properties AddressProperties, permanent bool) (*addressState, tcpip.Error) {
// attemptAddToPrimary is false when the address is already in the primary
// address list.
attemptAddToPrimary := true
@@ -208,7 +208,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
// We now promote the address.
for i, s := range a.mu.primary {
if s == addrState {
- switch peb {
+ switch properties.PEB {
case CanBePrimaryEndpoint:
// The address is already in the primary address list.
attemptAddToPrimary = false
@@ -222,7 +222,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
case NeverPrimaryEndpoint:
a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...)
default:
- panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
break
}
@@ -262,11 +262,11 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
}
// Acquire the address before returning it.
addrState.mu.refs++
- addrState.mu.deprecated = deprecated
- addrState.mu.configType = configType
+ addrState.mu.deprecated = properties.Deprecated
+ addrState.mu.configType = properties.ConfigType
if attemptAddToPrimary {
- switch peb {
+ switch properties.PEB {
case NeverPrimaryEndpoint:
case CanBePrimaryEndpoint:
a.mu.primary = append(a.mu.primary, addrState)
@@ -285,7 +285,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
a.mu.primary[0] = addrState
}
default:
- panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb))
+ panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB))
}
}
@@ -489,12 +489,12 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc
// Proceed to add a new temporary endpoint.
addr := localAddr.WithPrefix()
- ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */)
+ ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: tempPEB}, false /* permanent */)
if err != nil {
// addAndAcquireAddressLocked only returns an error if the address is
// already assigned but we just checked above if the address exists so we
// expect no error.
- panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err))
+ panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, AddressProperties{PEB: %s}, false): %s", addr, tempPEB, err))
}
// From https://golang.org/doc/faq#nil_error:
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 140f146f6..c55f85743 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -38,9 +38,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
}
{
- ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ ep, err := s.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint})
if err != nil {
- t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err)
+ t.Fatalf("s.AddAndAcquirePermanentAddress(%s, AddressProperties{PEB: NeverPrimaryEndpoint}): %s", addr, err)
}
// We don't need the address endpoint.
ep.DecRef()
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 068dab7ce..a3f403855 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -37,23 +37,9 @@ import (
// Our hash table has 16K buckets.
const numBuckets = 1 << 14
-// Direction of the tuple.
-type direction int
-
-const (
- dirOriginal direction = iota
- dirReply
-)
-
-// Manipulation type for the connection.
-// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and
-// DNAT at the same time.
-type manipType int
-
const (
- manipNone manipType = iota
- manipSource
- manipDestination
+ establishedTimeout time.Duration = 5 * 24 * time.Hour
+ unestablishedTimeout time.Duration = 120 * time.Second
)
// tuple holds a connection's identifying and manipulating data in one
@@ -64,13 +50,22 @@ type tuple struct {
// tupleEntry is used to build an intrusive list of tuples.
tupleEntry
- tupleID
-
// conn is the connection tracking entry this tuple belongs to.
conn *conn
- // direction is the direction of the tuple.
- direction direction
+ // reply is true iff the tuple's direction is opposite that of the first
+ // packet seen on the connection.
+ reply bool
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ tupleID tupleID
+}
+
+func (t *tuple) id() tupleID {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ return t.tupleID
}
// tupleID uniquely identifies a connection in one direction. It currently
@@ -103,50 +98,43 @@ func (ti tupleID) reply() tupleID {
//
// +stateify savable
type conn struct {
+ ct *ConnTrack
+
// original is the tuple in original direction. It is immutable.
original tuple
- // reply is the tuple in reply direction. It is immutable.
+ // reply is the tuple in reply direction.
reply tuple
- // manip indicates if the packet should be manipulated. It is immutable.
- // TODO(gvisor.dev/issue/5696): Support updating manipulation type.
- manip manipType
-
- // tcbHook indicates if the packet is inbound or outbound to
- // update the state of tcb. It is immutable.
- tcbHook Hook
-
- // mu protects all mutable state.
- mu sync.Mutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // Indicates that the connection has been finalized and may handle replies.
+ //
+ // +checklocks:mu
+ finalized bool
+ // sourceManip indicates the packet's source is manipulated.
+ //
+ // +checklocks:mu
+ sourceManip bool
+ // destinationManip indicates the packet's destination is manipulated.
+ //
+ // +checklocks:mu
+ destinationManip bool
// tcb is TCB control block. It is used to keep track of states
- // of tcp connection and is protected by mu.
+ // of tcp connection.
+ //
+ // +checklocks:mu
tcb tcpconntrack.TCB
// lastUsed is the last time the connection saw a relevant packet, and
- // is updated by each packet on the connection. It is protected by mu.
+ // is updated by each packet on the connection.
//
- // TODO(gvisor.dev/issue/5939): do not use the ambient clock.
- lastUsed time.Time `state:".(unixTime)"`
-}
-
-// newConn creates new connection.
-func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
- conn := conn{
- manip: manip,
- tcbHook: hook,
- lastUsed: time.Now(),
- }
- conn.original = tuple{conn: &conn, tupleID: orig}
- conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
- return &conn
+ // +checklocks:mu
+ lastUsed tcpip.MonotonicTime
}
// timedOut returns whether the connection timed out based on its state.
-func (cn *conn) timedOut(now time.Time) bool {
- const establishedTimeout = 5 * 24 * time.Hour
- const defaultTimeout = 120 * time.Second
- cn.mu.Lock()
- defer cn.mu.Unlock()
+func (cn *conn) timedOut(now tcpip.MonotonicTime) bool {
+ cn.mu.RLock()
+ defer cn.mu.RUnlock()
if cn.tcb.State() == tcpconntrack.ResultAlive {
// Use the same default as Linux, which doesn't delete
// established connections for 5(!) days.
@@ -154,22 +142,31 @@ func (cn *conn) timedOut(now time.Time) bool {
}
// Use the same default as Linux, which lets connections in most states
// other than established remain for <= 120 seconds.
- return now.Sub(cn.lastUsed) > defaultTimeout
+ return now.Sub(cn.lastUsed) > unestablishedTimeout
}
// update the connection tracking state.
//
-// Precondition: cn.mu must be held.
-func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+// +checklocks:cn.mu
+func (cn *conn) updateLocked(pkt *PacketBuffer, reply bool) {
+ if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
+ return
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
if cn.tcb.IsEmpty() {
cn.tcb.Init(tcpHeader)
- } else if hook == cn.tcbHook {
- cn.tcb.UpdateStateOutbound(tcpHeader)
- } else {
+ return
+ }
+
+ if reply {
cn.tcb.UpdateStateInbound(tcpHeader)
+ } else {
+ cn.tcb.UpdateStateOutbound(tcpHeader)
}
}
@@ -194,44 +191,37 @@ type ConnTrack struct {
// It is immutable.
seed uint32
+ // clock provides timing used to determine conntrack reapings.
+ clock tcpip.Clock
+
+ mu sync.RWMutex `state:"nosave"`
// mu protects the buckets slice, but not buckets' contents. Only take
// the write lock if you are modifying the slice or saving for S/R.
- mu sync.RWMutex `state:"nosave"`
-
- // buckets is protected by mu.
+ //
+ // +checklocks:mu
buckets []bucket
}
// +stateify savable
type bucket struct {
- // mu protects tuples.
- mu sync.Mutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
tuples tupleList
}
-// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
-// TCP header.
-//
-// Preconditions: pkt.NetworkHeader() is valid.
-func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
- netHeader := pkt.Network()
- if netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
- }
-
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
- return tupleID{}, &tcpip.ErrUnknownProtocol{}
+func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber:
+ if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize {
+ return tcpHeader, true
+ }
+ case header.UDPProtocolNumber:
+ if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize {
+ return udpHeader, true
+ }
}
- return tupleID{
- srcAddr: netHeader.SourceAddress(),
- srcPort: tcpHeader.SourcePort(),
- dstAddr: netHeader.DestinationAddress(),
- dstPort: tcpHeader.DestinationPort(),
- transProto: netHeader.TransportProtocol(),
- netProto: pkt.NetworkProtocolNumber,
- }, nil
+ return nil, false
}
func (ct *ConnTrack) init() {
@@ -240,278 +230,285 @@ func (ct *ConnTrack) init() {
ct.buckets = make([]bucket, numBuckets)
}
-// connFor gets the conn for pkt if it exists, or returns nil
-// if it does not. It returns an error when pkt does not contain a valid TCP
-// header.
-// TODO(gvisor.dev/issue/6168): Support UDP.
-func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil, dirOriginal
+func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple {
+ netHeader := pkt.Network()
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
+ return nil
+ }
+
+ tid := tupleID{
+ srcAddr: netHeader.SourceAddress(),
+ srcPort: transportHeader.SourcePort(),
+ dstAddr: netHeader.DestinationAddress(),
+ dstPort: transportHeader.DestinationPort(),
+ transProto: pkt.TransportProtocolNumber,
+ netProto: pkt.NetworkProtocolNumber,
}
- return ct.connForTID(tid)
-}
-func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
- bucket := ct.bucket(tid)
- now := time.Now()
+ bktID := ct.bucket(tid)
ct.mu.RLock()
- defer ct.mu.RUnlock()
- ct.buckets[bucket].mu.Lock()
- defer ct.buckets[bucket].mu.Unlock()
-
- // Iterate over the tuples in a bucket, cleaning up any unused
- // connections we find.
- for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
- // Clean up any timed-out connections we happen to find.
- if ct.reapTupleLocked(other, bucket, now) {
- // The tuple expired.
- continue
- }
- if tid == other.tupleID {
- return other.conn, other.direction
- }
+ bkt := &ct.buckets[bktID]
+ ct.mu.RUnlock()
+
+ now := ct.clock.NowMonotonic()
+ if t := bkt.connForTID(tid, now); t != nil {
+ return t
}
- return nil, dirOriginal
-}
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
+ // Make sure a connection wasn't added between when we last checked the
+ // bucket and acquired the bucket's write lock.
+ if t := bkt.connForTIDRLocked(tid, now); t != nil {
+ return t
}
- if hook != Prerouting && hook != Output {
- return nil
+
+ // This is the first packet we're seeing for the connection. Create an entry
+ // for this new connection.
+ conn := &conn{
+ ct: ct,
+ original: tuple{tupleID: tid},
+ reply: tuple{tupleID: tid.reply(), reply: true},
+ lastUsed: now,
}
+ conn.original.conn = conn
+ conn.reply.conn = conn
- replyTID := tid.reply()
- replyTID.srcAddr = address
- replyTID.srcPort = port
+ // For now, we only map an entry for the packet's original tuple as NAT may be
+ // performed on this connection. Until the packet goes through all the hooks
+ // and its final address/port is known, we cannot know what the response
+ // packet's addresses/ports will look like.
+ //
+ // This is okay because the destination cannot send its response until it
+ // receives the packet; the packet will only be received once all the hooks
+ // have been performed.
+ //
+ // See (*conn).finalize.
+ bkt.tuples.PushFront(&conn.original)
+ return &conn.original
+}
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
- }
- conn = newConn(tid, replyTID, manipDestination, hook)
- ct.insertConn(conn)
- return conn
+func (ct *ConnTrack) connForTID(tid tupleID) *tuple {
+ bktID := ct.bucket(tid)
+
+ ct.mu.RLock()
+ bkt := &ct.buckets[bktID]
+ ct.mu.RUnlock()
+
+ return bkt.connForTID(tid, ct.clock.NowMonotonic())
}
-func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return nil
- }
- if hook != Input && hook != Postrouting {
- return nil
+func (bkt *bucket) connForTID(tid tupleID, now tcpip.MonotonicTime) *tuple {
+ bkt.mu.RLock()
+ defer bkt.mu.RUnlock()
+ return bkt.connForTIDRLocked(tid, now)
+}
+
+// +checklocksread:bkt.mu
+func (bkt *bucket) connForTIDRLocked(tid tupleID, now tcpip.MonotonicTime) *tuple {
+ for other := bkt.tuples.Front(); other != nil; other = other.Next() {
+ if tid == other.id() && !other.conn.timedOut(now) {
+ return other
+ }
}
+ return nil
+}
- replyTID := tid.reply()
- replyTID.dstAddr = address
- replyTID.dstPort = port
+func (ct *ConnTrack) finalize(cn *conn) {
+ tid := cn.reply.id()
+ id := ct.bucket(tid)
- conn, _ := ct.connForTID(tid)
- if conn != nil {
- // The connection is already tracked.
- // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
- return nil
+ ct.mu.RLock()
+ bkt := &ct.buckets[id]
+ ct.mu.RUnlock()
+
+ bkt.mu.Lock()
+ defer bkt.mu.Unlock()
+
+ if t := bkt.connForTIDRLocked(tid, ct.clock.NowMonotonic()); t != nil {
+ // Another connection for the reply already exists. We can't do much about
+ // this so we leave the connection cn represents in a state where it can
+ // send packets but its responses will be mapped to some other connection.
+ // This may be okay if the connection only expects to send packets without
+ // any responses.
+ return
}
- conn = newConn(tid, replyTID, manipSource, hook)
- ct.insertConn(conn)
- return conn
+
+ bkt.tuples.PushFront(&cn.reply)
}
-// insertConn inserts conn into the appropriate table bucket.
-func (ct *ConnTrack) insertConn(conn *conn) {
- // Lock the buckets in the correct order.
- tupleBucket := ct.bucket(conn.original.tupleID)
- replyBucket := ct.bucket(conn.reply.tupleID)
- ct.mu.RLock()
- defer ct.mu.RUnlock()
- if tupleBucket < replyBucket {
- ct.buckets[tupleBucket].mu.Lock()
- ct.buckets[replyBucket].mu.Lock()
- } else if tupleBucket > replyBucket {
- ct.buckets[replyBucket].mu.Lock()
- ct.buckets[tupleBucket].mu.Lock()
- } else {
- // Both tuples are in the same bucket.
- ct.buckets[tupleBucket].mu.Lock()
- }
-
- // Now that we hold the locks, ensure the tuple hasn't been inserted by
- // another thread.
- // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too?
- alreadyInserted := false
- for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
- if other.tupleID == conn.original.tupleID {
- alreadyInserted = true
- break
+func (cn *conn) finalize() {
+ {
+ cn.mu.RLock()
+ finalized := cn.finalized
+ cn.mu.RUnlock()
+ if finalized {
+ return
}
}
- if !alreadyInserted {
- // Add the tuple to the map.
- ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
- ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ cn.mu.Lock()
+ finalized := cn.finalized
+ cn.finalized = true
+ cn.mu.Unlock()
+ if finalized {
+ return
}
- // Unlocking can happen in any order.
- ct.buckets[tupleBucket].mu.Unlock()
- if tupleBucket != replyBucket {
- ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
- }
+ cn.ct.finalize(cn)
}
-// handlePacket will manipulate the port and address of the packet if the
-// connection exists. Returns whether, after the packet traverses the tables,
-// it should create a new entry in the table.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
- if pkt.NatDone {
- return false
+// performNAT sets up the connection for the specified NAT.
+//
+// Generally, only the first packet of a connection reaches this method; other
+// packets will be manipulated without needing to modify the connection.
+func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) {
+ cn.performNATIfNoop(port, address, dnat)
+ cn.handlePacket(pkt, hook, r)
+}
+
+func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+
+ if cn.finalized {
+ return
}
- switch hook {
- case Prerouting, Input, Output, Postrouting:
- default:
- return false
+ if dnat {
+ if cn.destinationManip {
+ return
+ }
+ cn.destinationManip = true
+ } else {
+ if cn.sourceManip {
+ return
+ }
+ cn.sourceManip = true
}
- // TODO(gvisor.dev/issue/6168): Support UDP.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
+ cn.reply.mu.Lock()
+ defer cn.reply.mu.Unlock()
+
+ if dnat {
+ cn.reply.tupleID.srcAddr = address
+ cn.reply.tupleID.srcPort = port
+ } else {
+ cn.reply.tupleID.dstAddr = address
+ cn.reply.tupleID.dstPort = port
+ }
+}
+
+// handlePacket attempts to handle a packet and perform NAT if the connection
+// has had NAT performed on it.
+//
+// Returns true if the packet can skip the NAT table.
+func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
return false
}
- conn, dir := ct.connFor(pkt)
- // Connection not found for the packet.
- if conn == nil {
- // If this is the last hook in the data path for this packet (Input if
- // incoming, Postrouting if outgoing), indicate that a connection should be
- // inserted by the end of this hook.
- return hook == Input || hook == Postrouting
+ fullChecksum := false
+ updatePseudoHeader := false
+ natDone := &pkt.SNATDone
+ dnat := false
+ switch hook {
+ case Prerouting:
+ // Packet came from outside the stack so it must have a checksum set
+ // already.
+ fullChecksum = true
+ updatePseudoHeader = true
+
+ natDone = &pkt.DNATDone
+ dnat = true
+ case Input:
+ case Forward:
+ panic("should not handle packet in the forwarding hook")
+ case Output:
+ natDone = &pkt.DNATDone
+ dnat = true
+ fallthrough
+ case Postrouting:
+ if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
+ updatePseudoHeader = true
+ } else if rt.RequiresTXTransportChecksum() {
+ fullChecksum = true
+ updatePseudoHeader = true
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized hook = %d", hook))
}
- netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
- if len(tcpHeader) < header.TCPMinimumSize {
- return false
+ if *natDone {
+ panic(fmt.Sprintf("packet already had NAT(dnat=%t) performed at hook=%s; pkt=%#v", dnat, hook, pkt))
}
// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
// validated if checksum offloading is off. It may require IP defrag if the
// packets are fragmented.
- var newAddr tcpip.Address
- var newPort uint16
-
- updateSRCFields := false
-
- switch hook {
- case Prerouting, Output:
- if conn.manip == manipDestination {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.srcPort
- newAddr = conn.reply.srcAddr
- case dirReply:
- newPort = conn.original.dstPort
- newAddr = conn.original.dstAddr
-
- updateSRCFields = true
+ reply := pkt.tuple.reply
+ tid, performManip := func() (tupleID, bool) {
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+
+ // Mark the connection as having been used recently so it isn't reaped.
+ cn.lastUsed = cn.ct.clock.NowMonotonic()
+ // Update connection state.
+ cn.updateLocked(pkt, reply)
+
+ var tuple *tuple
+ if reply {
+ if dnat {
+ if !cn.sourceManip {
+ return tupleID{}, false
+ }
+ } else if !cn.destinationManip {
+ return tupleID{}, false
}
- pkt.NatDone = true
- }
- case Input, Postrouting:
- if conn.manip == manipSource {
- switch dir {
- case dirOriginal:
- newPort = conn.reply.dstPort
- newAddr = conn.reply.dstAddr
-
- updateSRCFields = true
- case dirReply:
- newPort = conn.original.srcPort
- newAddr = conn.original.srcAddr
+
+ tuple = &cn.original
+ } else {
+ if dnat {
+ if !cn.destinationManip {
+ return tupleID{}, false
+ }
+ } else if !cn.sourceManip {
+ return tupleID{}, false
}
- pkt.NatDone = true
+
+ tuple = &cn.reply
}
- default:
- panic(fmt.Sprintf("unrecognized hook = %s", hook))
- }
- if !pkt.NatDone {
+
+ return tuple.id(), true
+ }()
+ if !performManip {
return false
}
- fullChecksum := false
- updatePseudoHeader := false
- switch hook {
- case Prerouting, Input:
- case Output, Postrouting:
- // Calculate the TCP checksum and set it.
- if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
- updatePseudoHeader = true
- } else if r.RequiresTXTransportChecksum() {
- fullChecksum = true
- updatePseudoHeader = true
- }
- default:
- panic(fmt.Sprintf("unrecognized hook = %s", hook))
+ newPort := tid.dstPort
+ newAddr := tid.dstAddr
+ if dnat {
+ newPort = tid.srcPort
+ newAddr = tid.srcAddr
}
rewritePacket(
- netHeader,
- tcpHeader,
- updateSRCFields,
+ pkt.Network(),
+ transportHeader,
+ !dnat,
fullChecksum,
updatePseudoHeader,
newPort,
newAddr,
)
- // Update the state of tcb.
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
- // Mark the connection as having been used recently so it isn't reaped.
- conn.lastUsed = time.Now()
- // Update connection state.
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
-
- return false
-}
-
-// maybeInsertNoop tries to insert a no-op connection entry to keep connections
-// from getting clobbered when replies arrive. It only inserts if there isn't
-// already a connection for pkt.
-//
-// This should be called after traversing iptables rules only, to ensure that
-// pkt.NatDone is set correctly.
-func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
- // If there were a rule applying to this packet, it would be marked
- // with NatDone.
- if pkt.NatDone {
- return
- }
-
- // We only track TCP connections.
- if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
- return
- }
-
- // This is the first packet we're seeing for the TCP connection. Insert
- // the noop entry (an identity mapping) so that the response doesn't
- // get NATed, breaking the connection.
- tid, err := packetToTupleID(pkt)
- if err != nil {
- return
- }
- conn := newConn(tid, tid.reply(), manipNone, hook)
- conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
- ct.insertConn(conn)
+ *natDone = true
+ return true
}
// bucket gets the conntrack bucket for a tupleID.
@@ -555,7 +552,7 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
const minInterval = 10 * time.Millisecond
const maxInterval = maxFullTraversal / fractionPerReaping
- now := time.Now()
+ now := ct.clock.NowMonotonic()
checked := 0
expired := 0
var idx int
@@ -563,14 +560,20 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
defer ct.mu.RUnlock()
for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
idx = (i + start) % len(ct.buckets)
- ct.buckets[idx].mu.Lock()
- for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ bkt := &ct.buckets[idx]
+ bkt.mu.Lock()
+ for tuple := bkt.tuples.Front(); tuple != nil; {
+ // reapTupleLocked updates tuple's next pointer so we grab it here.
+ nextTuple := tuple.Next()
+
checked++
- if ct.reapTupleLocked(tuple, idx, now) {
+ if ct.reapTupleLocked(tuple, idx, bkt, now) {
expired++
}
+
+ tuple = nextTuple
}
- ct.buckets[idx].mu.Unlock()
+ bkt.mu.Unlock()
}
// We already checked buckets[idx].
idx++
@@ -595,44 +598,51 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
// reapTupleLocked tries to remove tuple and its reply from the table. It
// returns whether the tuple's connection has timed out.
//
-// Preconditions:
-// * ct.mu is locked for reading.
-// * bucket is locked.
-func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+// Precondition: ct.mu is read locked and bkt.mu is write locked.
+// +checklocksread:ct.mu
+// +checklocks:bkt.mu
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool {
if !tuple.conn.timedOut(now) {
return false
}
- // To maintain lock order, we can only reap these tuples if the reply
- // appears later in the table.
- replyBucket := ct.bucket(tuple.reply())
- if bucket > replyBucket {
+ // To maintain lock order, we can only reap both tuples if the reply appears
+ // later in the table.
+ replyBktID := ct.bucket(tuple.id().reply())
+ tuple.conn.mu.RLock()
+ replyTupleInserted := tuple.conn.finalized
+ tuple.conn.mu.RUnlock()
+ if bktID > replyBktID && replyTupleInserted {
return true
}
- // Don't re-lock if both tuples are in the same bucket.
- differentBuckets := bucket != replyBucket
- if differentBuckets {
- ct.buckets[replyBucket].mu.Lock()
+ // Reap the reply.
+ if replyTupleInserted {
+ // Don't re-lock if both tuples are in the same bucket.
+ if bktID != replyBktID {
+ replyBkt := &ct.buckets[replyBktID]
+ replyBkt.mu.Lock()
+ removeConnFromBucket(replyBkt, tuple)
+ replyBkt.mu.Unlock()
+ } else {
+ removeConnFromBucket(bkt, tuple)
+ }
}
- // We have the buckets locked and can remove both tuples.
- if tuple.direction == dirOriginal {
- ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
- } else {
- ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
- }
- ct.buckets[bucket].tuples.Remove(tuple)
+ bkt.tuples.Remove(tuple)
+ return true
+}
- // Don't re-unlock if both tuples are in the same bucket.
- if differentBuckets {
- ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
+// +checklocks:b.mu
+func removeConnFromBucket(b *bucket, tuple *tuple) {
+ if tuple.reply {
+ b.tuples.Remove(&tuple.conn.original)
+ } else {
+ b.tuples.Remove(&tuple.conn.reply)
}
-
- return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -640,17 +650,22 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
srcPort: epID.LocalPort,
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
- transProto: header.TCPProtocolNumber,
+ transProto: transProto,
netProto: netProto,
}
- conn, _ := ct.connForTID(tid)
- if conn == nil {
+ t := ct.connForTID(tid)
+ if t == nil {
// Not a tracked connection.
return "", 0, &tcpip.ErrNotConnected{}
- } else if conn.manip != manipDestination {
+ }
+
+ t.conn.mu.RLock()
+ defer t.conn.mu.RUnlock()
+ if !t.conn.destinationManip {
// Unmanipulated destination.
return "", 0, &tcpip.ErrInvalidOptionValue{}
}
- return conn.original.dstAddr, conn.original.dstPort, nil
+ id := t.conn.original.id()
+ return id.dstAddr, id.dstPort, nil
}
diff --git a/pkg/tcpip/stack/conntrack_test.go b/pkg/tcpip/stack/conntrack_test.go
new file mode 100644
index 000000000..fb0645ed1
--- /dev/null
+++ b/pkg/tcpip/stack/conntrack_test.go
@@ -0,0 +1,132 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+)
+
+func TestReap(t *testing.T) {
+ // Initialize conntrack.
+ clock := faketime.NewManualClock()
+ ct := ConnTrack{
+ clock: clock,
+ }
+ ct.init()
+ ct.checkNumTuples(t, 0)
+
+ // Simulate sending a SYN. This will get the connection into conntrack, but
+ // the connection won't be considered established. Thus the timeout for
+ // reaping is unestablishedTimeout.
+ pkt1 := genTCPPacket()
+ pkt1.tuple = ct.getConnOrMaybeInsertNoop(pkt1)
+ // We set rt.routeInfo.Loop to avoid a panic when handlePacket calls
+ // rt.RequiresTXTransportChecksum.
+ var rt Route
+ rt.routeInfo.Loop = PacketLoop
+ if pkt1.tuple.conn.handlePacket(pkt1, Output, &rt) {
+ t.Fatal("handlePacket() shouldn't perform any NAT")
+ }
+ ct.checkNumTuples(t, 1)
+
+ // Travel a little into the future and send the same SYN. This should update
+ // lastUsed, but per #6748 didn't.
+ clock.Advance(unestablishedTimeout / 2)
+ pkt2 := genTCPPacket()
+ pkt2.tuple = ct.getConnOrMaybeInsertNoop(pkt2)
+ if pkt2.tuple.conn.handlePacket(pkt2, Output, &rt) {
+ t.Fatal("handlePacket() shouldn't perform any NAT")
+ }
+ ct.checkNumTuples(t, 1)
+
+ // Travel farther into the future - enough that failing to update lastUsed
+ // would cause a reaping - and reap the whole table. Make sure the connection
+ // hasn't been reaped.
+ clock.Advance(unestablishedTimeout * 3 / 4)
+ ct.reapEverything()
+ ct.checkNumTuples(t, 1)
+
+ // Travel past unestablishedTimeout to confirm the tuple is gone.
+ clock.Advance(unestablishedTimeout / 2)
+ ct.reapEverything()
+ ct.checkNumTuples(t, 0)
+}
+
+// genTCPPacket returns an initialized IPv4 TCP packet.
+func genTCPPacket() *PacketBuffer {
+ const packetLen = header.IPv4MinimumSize + header.TCPMinimumSize
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: packetLen,
+ })
+ pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ pkt.TransportProtocolNumber = header.TCPProtocolNumber
+ tcpHdr := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize))
+ tcpHdr.Encode(&header.TCPFields{
+ SrcPort: 5555,
+ DstPort: 6666,
+ SeqNum: 7777,
+ AckNum: 8888,
+ DataOffset: header.TCPMinimumSize,
+ Flags: header.TCPFlagSyn,
+ WindowSize: 50000,
+ Checksum: 0, // Conntrack doesn't verify the checksum.
+ })
+ ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
+ ipHdr.Encode(&header.IPv4Fields{
+ TotalLength: packetLen,
+ Protocol: uint8(header.TCPProtocolNumber),
+ SrcAddr: testutil.MustParse4("1.0.0.1"),
+ DstAddr: testutil.MustParse4("1.0.0.2"),
+ Checksum: 0, // Conntrack doesn't verify the checksum.
+ })
+
+ return pkt
+}
+
+// checkNumTuples checks that there are exactly want tuples tracked by
+// conntrack.
+func (ct *ConnTrack) checkNumTuples(t *testing.T, want int) {
+ t.Helper()
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+
+ var total int
+ for idx := range ct.buckets {
+ ct.buckets[idx].mu.RLock()
+ total += ct.buckets[idx].tuples.Len()
+ ct.buckets[idx].mu.RUnlock()
+ }
+
+ if total != want {
+ t.Fatalf("checkNumTuples: got %d, wanted %d", total, want)
+ }
+}
+
+func (ct *ConnTrack) reapEverything() {
+ var bucket int
+ for {
+ newBucket, _ := ct.reapUnused(bucket, 0 /* ignored */)
+ // We started reaping at bucket 0. If the next bucket isn't after our
+ // current bucket, we've gone through them all.
+ if newBucket <= bucket {
+ break
+ }
+ bucket = newBucket
+ }
+}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index ccb69393b..c2f1f4798 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -181,10 +181,6 @@ func (*fwdTestNetworkProtocol) MinimumPacketSize() int {
return fwdTestNetHeaderLen
}
-func (*fwdTestNetworkProtocol) DefaultPrefixLen() int {
- return fwdTestNetDefaultPrefixLen
-}
-
func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
@@ -384,8 +380,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M
if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC #1 failed:", err)
}
- if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress #1 failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fwdTestNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fwdTestNetDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
// NIC 2 has the link address "b", and added the network address 2.
@@ -397,8 +400,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M
if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC #2 failed:", err)
}
- if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress #2 failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fwdTestNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fwdTestNetDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
nic, ok := s.nics[2]
diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go
index 3a20839da..99e5d2df7 100644
--- a/pkg/tcpip/stack/icmp_rate_limit.go
+++ b/pkg/tcpip/stack/icmp_rate_limit.go
@@ -16,6 +16,7 @@ package stack
import (
"golang.org/x/time/rate"
+ "gvisor.dev/gvisor/pkg/tcpip"
)
const (
@@ -31,11 +32,41 @@ const (
// ICMPRateLimiter is a global rate limiter that controls the generation of
// ICMP messages generated by the stack.
type ICMPRateLimiter struct {
- *rate.Limiter
+ limiter *rate.Limiter
+ clock tcpip.Clock
}
// NewICMPRateLimiter returns a global rate limiter for controlling the rate
-// at which ICMP messages are generated by the stack.
-func NewICMPRateLimiter() *ICMPRateLimiter {
- return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)}
+// at which ICMP messages are generated by the stack. The returned limiter
+// does not apply limits to any ICMP types by default.
+func NewICMPRateLimiter(clock tcpip.Clock) *ICMPRateLimiter {
+ return &ICMPRateLimiter{
+ clock: clock,
+ limiter: rate.NewLimiter(icmpLimit, icmpBurst),
+ }
+}
+
+// SetLimit sets a new Limit for the limiter.
+func (l *ICMPRateLimiter) SetLimit(limit rate.Limit) {
+ l.limiter.SetLimitAt(l.clock.Now(), limit)
+}
+
+// Limit returns the maximum overall event rate.
+func (l *ICMPRateLimiter) Limit() rate.Limit {
+ return l.limiter.Limit()
+}
+
+// SetBurst sets a new burst size for the limiter.
+func (l *ICMPRateLimiter) SetBurst(burst int) {
+ l.limiter.SetBurstAt(l.clock.Now(), burst)
+}
+
+// Burst returns the maximum burst size.
+func (l *ICMPRateLimiter) Burst() int {
+ return l.limiter.Burst()
+}
+
+// Allow reports whether one ICMP message may be sent now.
+func (l *ICMPRateLimiter) Allow() bool {
+ return l.limiter.AllowN(l.clock.Now(), 1)
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index f152c0d83..fd61387bf 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
-func DefaultTables(seed uint32) *IPTables {
+func DefaultTables(seed uint32, clock tcpip.Clock) *IPTables {
return &IPTables{
v4Tables: [NumTables]Table{
NATID: {
@@ -182,7 +182,8 @@ func DefaultTables(seed uint32) *IPTables {
Postrouting: {MangleID, NATID},
},
connections: ConnTrack{
- seed: seed,
+ seed: seed,
+ clock: clock,
},
reaperDone: make(chan struct{}, 1),
}
@@ -264,33 +265,125 @@ const (
chainReturn
)
-// Check runs pkt through the rules for hook. It returns true when the packet
-// should continue traversing the network stack and false when it should be
-// dropped.
+// CheckPrerouting performs the prerouting hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool {
+ const hook = Prerouting
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ pkt.tuple = it.connections.getConnOrMaybeInsertNoop(pkt)
+
+ return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */)
+}
+
+// CheckInput performs the input hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool {
+ const hook = Input
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */)
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
+ }
+ pkt.tuple = nil
+ return ret
+}
+
+// CheckForward performs the forward hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool {
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+ return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName)
+}
+
+// CheckOutput performs the output hook on the packet.
+//
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool {
+ const hook = Output
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
+ return true
+ }
+
+ pkt.tuple = it.connections.getConnOrMaybeInsertNoop(pkt)
+
+ return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
+}
+
+// CheckPostrouting performs the postrouting hook on the packet.
//
-// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
- if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
+// Returns true iff the packet may continue traversing the stack; the packet
+// must be dropped if false is returned.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool {
+ const hook = Postrouting
+
+ if it.shouldSkip(pkt.NetworkProtocolNumber) {
return true
}
+
+ ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName)
+ if t := pkt.tuple; t != nil {
+ t.conn.finalize()
+ }
+ pkt.tuple = nil
+ return ret
+}
+
+func (it *IPTables) shouldSkip(netProto tcpip.NetworkProtocolNumber) bool {
+ switch netProto {
+ case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
+ default:
+ // IPTables only supports IPv4/IPv6.
+ return true
+ }
+
+ it.mu.RLock()
+ defer it.mu.RUnlock()
// Many users never configure iptables. Spare them the cost of rule
// traversal if rules have never been set.
+ return !it.modified
+}
+
+// check runs pkt through the rules for hook. It returns true when the packet
+// should continue traversing the network stack and false when it should be
+// dropped.
+//
+// Precondition: The packet's network and transport header must be set.
+func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool {
it.mu.RLock()
defer it.mu.RUnlock()
- if !it.modified {
- return true
- }
-
- // Packets are manipulated only if connection and matching
- // NAT rule exists.
- shouldTrack := it.connections.handlePacket(pkt, hook, r)
// Go through each table containing the hook.
priorities := it.priorities[hook]
for _, tableID := range priorities {
- // If handlePacket already NATed the packet, we don't need to
- // check the NAT table.
- if tableID == NATID && pkt.NatDone {
+ if t := pkt.tuple; t != nil && tableID == NATID && t.conn.handlePacket(pkt, hook, r) {
continue
}
var table Table
@@ -300,7 +393,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
table = it.v4Tables[tableID]
}
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -311,7 +404,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v {
+ switch v, _ := underflow.Target.Action(pkt, hook, r, addressEP); v {
case RuleAccept:
continue
case RuleDrop:
@@ -327,21 +420,6 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr
}
}
- // If this connection should be tracked, try to add an entry for it. If
- // traversing the nat table didn't end in adding an entry,
- // maybeInsertNoop will add a no-op entry for the connection. This is
- // needeed when establishing connections so that the SYN/ACK reply to an
- // outgoing SYN is delivered to the correct endpoint rather than being
- // redirected by a prerouting rule.
- //
- // From the iptables documentation: "If there is no rule, a `null'
- // binding is created: this usually does not map the packet, but exists
- // to ensure we don't map another stream over an existing one."
- if shouldTrack {
- it.connections.maybeInsertNoop(pkt, hook)
- }
-
- // Every table returned Accept.
return true
}
@@ -375,30 +453,46 @@ func (it *IPTables) startReaper(interval time.Duration) {
}()
}
-// CheckPackets runs pkts through the rules for hook and returns a map of packets that
-// should not go forward.
+// CheckOutputPackets performs the output hook on the packets.
//
-// Preconditions:
-// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// * pkt.NetworkHeader is not nil.
+// Returns a map of packets that must be dropped.
//
-// NOTE: unlike the Check API the returned map contains packets that should be
-// dropped.
-func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+// Precondition: The packets' network and transport header must be set.
+func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckOutput(pkt, r, outNicName)
+ }, true /* dnat */)
+}
+
+// CheckPostroutingPackets performs the postrouting hook on the packets.
+//
+// Returns a map of packets that must be dropped.
+//
+// Precondition: The packets' network and transport header must be set.
+func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckPostrouting(pkt, r, addressEP, outNicName)
+ }, false /* dnat */)
+}
+
+func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool, dnat bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if !pkt.NatDone {
- if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok {
- if drop == nil {
- drop = make(map[*PacketBuffer]struct{})
- }
- drop[pkt] = struct{}{}
+ natDone := &pkt.SNATDone
+ if dnat {
+ natDone = &pkt.DNATDone
+ }
+
+ if ok := f(pkt); !ok {
+ if drop == nil {
+ drop = make(map[*PacketBuffer]struct{})
}
- if pkt.NatDone {
- if natPkts == nil {
- natPkts = make(map[*PacketBuffer]struct{})
- }
- natPkts[pkt] = struct{}{}
+ drop[pkt] = struct{}{}
+ }
+ if *natDone {
+ if natPkts == nil {
+ natPkts = make(map[*PacketBuffer]struct{})
}
+ natPkts[pkt] = struct{}{}
}
}
return drop, natPkts
@@ -407,11 +501,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inN
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict {
case RuleAccept:
return chainAccept
@@ -428,7 +522,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, addressEP, inNicName, outNicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -454,7 +548,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
@@ -477,16 +571,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr)
+ return rule.Target.Action(pkt, hook, r, addressEP)
}
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return "", 0, &tcpip.ErrNotConnected{}
}
- return it.connections.originalDst(epID, netProto)
+ return it.connections.originalDst(epID, netProto, transProto)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 96cc899bb..ef515bdd2 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -29,7 +29,7 @@ type AcceptTarget struct {
}
// Action implements Target.Action.
-func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -40,7 +40,7 @@ type DropTarget struct {
}
// Action implements Target.Action.
-func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type ErrorTarget struct {
}
// Action implements Target.Action.
-func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -67,7 +67,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -79,10 +79,49 @@ type ReturnTarget struct {
}
// Action implements Target.Action.
-func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) {
return RuleReturn, 0
}
+// DNATTarget modifies the destination port/IP of packets.
+type DNATTarget struct {
+ // The new destination address for packets.
+ //
+ // Immutable.
+ Addr tcpip.Address
+
+ // The new destination port for packets.
+ //
+ // Immutable.
+ Port uint16
+
+ // NetworkProtocol is the network protocol the target is used with.
+ //
+ // Immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (rt *DNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "DNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ rt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ switch hook {
+ case Prerouting, Output:
+ case Input, Forward, Postrouting:
+ panic(fmt.Sprintf("%s not supported for DNAT", hook))
+ default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ return natAction(pkt, hook, r, rt.Port, rt.Addr, true /* dnat */)
+
+}
+
// RedirectTarget redirects the packet to this machine by modifying the
// destination port/IP. Outgoing packets are redirected to the loopback device,
// and incoming packets are redirected to the incoming interface (rather than
@@ -97,7 +136,7 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -105,18 +144,9 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
rt.NetworkProtocol, pkt.NetworkProtocolNumber))
}
- // Packet is already manipulated.
- if pkt.NatDone {
- return RuleAccept, 0
- }
-
- // Drop the packet if network and transport header are not set.
- if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
- return RuleDrop, 0
- }
-
// Change the address to loopback (127.0.0.1 or ::1) in Output and to
// the primary address of the incoming interface in Prerouting.
+ var address tcpip.Address
switch hook {
case Output:
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
@@ -125,48 +155,13 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
address = header.IPv6Loopback
}
case Prerouting:
- // No-op, as address is already set correctly.
+ // addressEP is expected to be set for the prerouting hook.
+ address = addressEP.MainAddress().Address
default:
panic("redirect target is supported only on output and prerouting hooks")
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- udpHeader := header.UDP(pkt.TransportHeader().View())
-
- if hook == Output {
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- udpHeader,
- false, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- rt.Port,
- address,
- )
- } else {
- udpHeader.SetDestinationPort(rt.Port)
- }
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
- }
-
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
- default:
- return RuleDrop, 0
- }
-
- return RuleAccept, 0
+ return natAction(pkt, hook, r, rt.Port, address, true /* dnat */)
}
// SNATTarget modifies the source port/IP in the outgoing packets.
@@ -179,8 +174,36 @@ type SNATTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
+func natAction(pkt *PacketBuffer, hook Hook, r *Route, port uint16, address tcpip.Address, dnat bool) (RuleVerdict, int) {
+ // Drop the packet if network and transport header are not set.
+ if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
+ return RuleDrop, 0
+ }
+
+ t := pkt.tuple
+ if t == nil {
+ return RuleDrop, 0
+ }
+
+ // TODO(https://gvisor.dev/issue/5773): If the port is in use, pick a
+ // different port.
+ if port == 0 {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
+ case header.UDPProtocolNumber:
+ port = header.UDP(pkt.TransportHeader().View()).SourcePort()
+ case header.TCPProtocolNumber:
+ port = header.TCP(pkt.TransportHeader().View()).SourcePort()
+ default:
+ panic(fmt.Sprintf("unsupported transport protocol = %d", pkt.TransportProtocolNumber))
+ }
+ }
+
+ t.conn.performNAT(pkt, hook, r, port, address, dnat)
+ return RuleAccept, 0
+}
+
// Action implements Target.Action.
-func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (st *SNATTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
// Sanity check.
if st.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -188,16 +211,6 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
st.NetworkProtocol, pkt.NetworkProtocolNumber))
}
- // Packet is already manipulated.
- if pkt.NatDone {
- return RuleAccept, 0
- }
-
- // Drop the packet if network and transport header are not set.
- if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
- return RuleDrop, 0
- }
-
switch hook {
case Postrouting, Input:
case Prerouting, Output, Forward:
@@ -206,37 +219,43 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
panic(fmt.Sprintf("%s unrecognized", hook))
}
- switch protocol := pkt.TransportProtocolNumber; protocol {
- case header.UDPProtocolNumber:
- // Only calculate the checksum if offloading isn't supported.
- requiresChecksum := r.RequiresTXTransportChecksum()
- rewritePacket(
- pkt.Network(),
- header.UDP(pkt.TransportHeader().View()),
- true, /* updateSRCFields */
- requiresChecksum,
- requiresChecksum,
- st.Port,
- st.Addr,
- )
-
- pkt.NatDone = true
- case header.TCPProtocolNumber:
- if ct == nil {
- return RuleAccept, 0
- }
+ return natAction(pkt, hook, r, st.Port, st.Addr, false /* dnat */)
+}
- // Set up conection for matching NAT rule. Only the first
- // packet of the connection comes here. Other packets will be
- // manipulated in connection tracking.
- if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
- ct.handlePacket(pkt, hook, r)
- }
+// MasqueradeTarget modifies the source port/IP in the outgoing packets.
+type MasqueradeTarget struct {
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if mt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ mt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ switch hook {
+ case Postrouting:
+ case Prerouting, Input, Forward, Output:
+ panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook))
default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ // addressEP is expected to be set for the postrouting hook.
+ ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), false /* allowExpired */)
+ if ep == nil {
+ // No address exists that we can use as a source address.
return RuleDrop, 0
}
- return RuleAccept, 0
+ address := ep.AddressWithPrefix().Address
+ ep.DecRef()
+ return natAction(pkt, hook, r, 0 /* port */, address, false /* dnat */)
}
func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) {
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 66e5f22ac..b22024667 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -81,17 +81,6 @@ const (
//
// +stateify savable
type IPTables struct {
- // mu protects v4Tables, v6Tables, and modified.
- mu sync.RWMutex
- // v4Tables and v6tables map tableIDs to tables. They hold builtin
- // tables only, not user tables. mu must be locked for accessing.
- v4Tables [NumTables]Table
- v6Tables [NumTables]Table
- // modified is whether tables have been modified at least once. It is
- // used to elide the iptables performance overhead for workloads that
- // don't utilize iptables.
- modified bool
-
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. It is immutable.
@@ -101,6 +90,21 @@ type IPTables struct {
// reaperDone can be signaled to stop the reaper goroutine.
reaperDone chan struct{}
+
+ mu sync.RWMutex
+ // v4Tables and v6tables map tableIDs to tables. They hold builtin
+ // tables only, not user tables.
+ //
+ // +checklocks:mu
+ v4Tables [NumTables]Table
+ // +checklocks:mu
+ v6Tables [NumTables]Table
+ // modified is whether tables have been modified at least once. It is
+ // used to elide the iptables performance overhead for workloads that
+ // don't utilize iptables.
+ //
+ // +checklocks:mu
+ modified bool
}
// VisitTargets traverses all the targets of all tables and replaces each with
@@ -352,5 +356,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int)
+ Action(*PacketBuffer, Hook, *Route, AddressableEndpoint) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 4d5431da1..40b33b6b5 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -333,8 +333,12 @@ func TestDADDisabled(t *testing.T) {
Address: addr1,
PrefixLen: defaultPrefixLen,
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Should get the address immediately since we should not have performed
@@ -379,12 +383,15 @@ func TestDADResolveLoopback(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: addr1,
- PrefixLen: defaultPrefixLen,
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ },
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
@@ -517,8 +524,12 @@ func TestDADResolve(t *testing.T) {
Address: addr1,
PrefixLen: defaultPrefixLen,
}
- if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err)
}
// Make sure the address does not resolve before the resolution time has
@@ -740,8 +751,12 @@ func TestDADFail(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet
@@ -778,8 +793,8 @@ func TestDADFail(t *testing.T) {
// Attempting to add the address again should not fail if the address's
// state was cleaned up when DAD failed.
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
})
}
@@ -851,8 +866,12 @@ func TestDADStop(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
@@ -975,17 +994,29 @@ func TestSetNDPConfigurations(t *testing.T) {
// Add addresses for each NIC.
addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix1,
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID1, protocolAddr1, err)
}
addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix2,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID2, protocolAddr2, err)
}
expectDADEvent(nicID2, addr2)
addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen}
- if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addrWithPrefix3,
+ }
+ if err := s.AddProtocolAddress(nicID3, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID3, protocolAddr3, err)
}
expectDADEvent(nicID3, addr3)
@@ -2788,8 +2819,12 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
continue
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: test.addrs[j].Address.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{}
@@ -3644,8 +3679,9 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
Protocol: header.IPv6ProtocolNumber,
AddressWithPrefix: addr2,
}
- if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, protoAddr2, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v) = %s", nicID, protoAddr2, properties, err)
}
// addr2 should be more preferred now since it is at the front of the primary
// list.
@@ -3733,8 +3769,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
}
// Add the address as a static address before SLAAC tries to add it.
- if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
+ protocolAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) = %s", protocolAddr, err)
}
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -4073,8 +4110,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// Attempting to add the address manually should not fail if the
// address's state was cleaned up when DAD failed.
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
if err := s.RemoveAddress(nicID, addr.Address); err != nil {
t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
@@ -5362,8 +5403,12 @@ func TestRouterSolicitation(t *testing.T) {
}
if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index a796942ab..e251e3b24 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -97,6 +97,8 @@ type packetEndpointList struct {
mu sync.RWMutex
// eps is protected by mu, but the contained PacketEndpoint values are not.
+ //
+ // +checklocks:mu
eps []PacketEndpoint
}
@@ -117,6 +119,12 @@ func (p *packetEndpointList) remove(ep PacketEndpoint) {
}
}
+func (p *packetEndpointList) len() int {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+ return len(p.eps)
+}
+
// forEach calls fn with each endpoints in p while holding the read lock on p.
func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) {
p.mu.RLock()
@@ -157,14 +165,8 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0
- // Register supported packet and network endpoint protocols.
- for _, netProto := range header.Ethertypes {
- nic.packetEPs.eps[netProto] = new(packetEndpointList)
- }
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
- nic.packetEPs.eps[netNum] = new(packetEndpointList)
-
netEP := netProto.NewEndpoint(nic, nic)
nic.networkEndpoints[netNum] = netEP
@@ -514,7 +516,7 @@ func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber,
// addAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error {
ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
return &tcpip.ErrUnknownProtocol{}
@@ -525,7 +527,7 @@ func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
return &tcpip.ErrNotSupported{}
}
- addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
+ addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, properties)
if err == nil {
// We have no need for the address endpoint.
addressEndpoint.DecRef()
@@ -831,24 +833,9 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt
transProto := state.proto
- // TransportHeader is empty only when pkt is an ICMP packet or was reassembled
- // from fragments.
if pkt.TransportHeader().View().IsEmpty() {
- // ICMP packets don't have their TransportHeader fields set yet, parse it
- // here. See icmp/protocol.go:protocol.Parse for a full explanation.
- if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
- // ICMP packets may be longer, but until icmp.Parse is implemented, here
- // we parse it using the minimum size.
- if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok {
- n.stats.malformedL4RcvdPackets.Increment()
- // We consider a malformed transport packet handled because there is
- // nothing the caller can do.
- return TransportPacketHandled
- }
- } else if !transProto.Parse(pkt) {
- n.stats.malformedL4RcvdPackets.Increment()
- return TransportPacketHandled
- }
+ n.stats.malformedL4RcvdPackets.Increment()
+ return TransportPacketHandled
}
srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View())
@@ -974,7 +961,8 @@ func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
eps, ok := n.packetEPs.eps[netProto]
if !ok {
- return &tcpip.ErrNotSupported{}
+ eps = new(packetEndpointList)
+ n.packetEPs.eps[netProto] = eps
}
eps.add(ep)
@@ -990,6 +978,9 @@ func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
return
}
eps.remove(ep)
+ if eps.len() == 0 {
+ delete(n.packetEPs.eps, netProto)
+ }
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 5cb342f78..c8ad93f29 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -127,11 +127,6 @@ func (*testIPv6Protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
-// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen.
-func (*testIPv6Protocol) DefaultPrefixLen() int {
- return header.IPv6AddressSize * 8
-}
-
// ParseAddresses implements NetworkProtocol.ParseAddresses.
func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 29c22bfd4..c4a4bbd22 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -126,9 +126,13 @@ type PacketBuffer struct {
EgressRoute RouteInfo
GSOOptions GSO
- // NatDone indicates if the packet has been manipulated as per NAT
- // iptables rule.
- NatDone bool
+ // SNATDone indicates if the packet's source has been manipulated as per
+ // iptables NAT table.
+ SNATDone bool
+
+ // DNATDone indicates if the packet's destination has been manipulated as per
+ // iptables NAT table.
+ DNATDone bool
// PktType indicates the SockAddrLink.PacketType of the packet as defined in
// https://www.man7.org/linux/man-pages/man7/packet.7.html.
@@ -143,6 +147,8 @@ type PacketBuffer struct {
// NetworkPacketInfo holds an incoming packet's network-layer information.
NetworkPacketInfo NetworkPacketInfo
+
+ tuple *tuple
}
// NewPacketBuffer creates a new PacketBuffer with opts.
@@ -296,12 +302,14 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
Owner: pk.Owner,
GSOOptions: pk.GSOOptions,
NetworkProtocolNumber: pk.NetworkProtocolNumber,
- NatDone: pk.NatDone,
+ DNATDone: pk.DNATDone,
+ SNATDone: pk.SNATDone,
TransportProtocolNumber: pk.TransportProtocolNumber,
PktType: pk.PktType,
NICID: pk.NICID,
RXTransportChecksumValidated: pk.RXTransportChecksumValidated,
NetworkPacketInfo: pk.NetworkPacketInfo,
+ tuple: pk.tuple,
}
}
@@ -329,15 +337,41 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
buf: pk.buf.Clone(),
// Treat unfilled header portion as reserved.
reserved: pk.AvailableHeaderBytes(),
+ tuple: pk.tuple,
+ }
+ return newPk
+}
+
+// DeepCopyForForwarding creates a deep copy of the packet buffer for
+// forwarding.
+//
+// The returned packet buffer will have the network and transport headers
+// set if the original packet buffer did.
+func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer {
+ newPk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: reservedHeaderBytes,
+ Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(),
+ IsForwardedPacket: true,
+ })
+
+ {
+ consumeBytes := pk.NetworkHeader().View().Size()
+ if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed {
+ panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes))
+ }
+ newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber
}
- // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
- // maintain this flag in the packet. Currently conntrack needs this flag to
- // tell if a noop connection should be inserted at Input hook. Once conntrack
- // redefines the manipulation field as mutable, we won't need the special noop
- // connection.
- if pk.NatDone {
- newPk.NatDone = true
+
+ {
+ consumeBytes := pk.TransportHeader().View().Size()
+ if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed {
+ panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes))
+ }
+ newPk.TransportProtocolNumber = pk.TransportProtocolNumber
}
+
+ newPk.tuple = pk.tuple
+
return newPk
}
@@ -389,13 +423,14 @@ func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) {
return d.pk.buf.PullUp(d.pk.dataOffset(), size)
}
-// DeleteFront removes count from the beginning of d. It panics if count >
-// d.Size(). All backing storage references after the front of the d are
-// invalidated.
-func (d PacketData) DeleteFront(count int) {
- if !d.pk.buf.Remove(d.pk.dataOffset(), count) {
- panic("count > d.Size()")
+// Consume is the same as PullUp except that is additionally consumes the
+// returned bytes. Subsequent PullUp or Consume will not return these bytes.
+func (d PacketData) Consume(size int) (tcpipbuffer.View, bool) {
+ v, ok := d.PullUp(size)
+ if ok {
+ d.pk.consumed += size
}
+ return v, ok
}
// CapLength reduces d to at most length bytes.
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index 87b023445..c376ed1a1 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -123,32 +123,6 @@ func TestPacketHeaderPush(t *testing.T) {
}
}
-func TestPacketBufferClone(t *testing.T) {
- data := concatViews(makeView(20), makeView(30), makeView(40))
- pk := NewPacketBuffer(PacketBufferOptions{
- // Make a copy of data to make sure our truth data won't be taint by
- // PacketBuffer.
- Data: buffer.NewViewFromBytes(data).ToVectorisedView(),
- })
-
- bytesToDelete := 30
- originalSize := data.Size()
-
- clonedPks := []*PacketBuffer{
- pk.Clone(),
- pk.CloneToInbound(),
- }
- pk.Data().DeleteFront(bytesToDelete)
- if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want {
- t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got)
- }
- for _, clonedPk := range clonedPks {
- if got := clonedPk.Data().Size(); got != originalSize {
- t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got)
- }
- }
-}
-
func TestPacketHeaderConsume(t *testing.T) {
for _, test := range []struct {
name string
@@ -461,11 +435,17 @@ func TestPacketBufferData(t *testing.T) {
}
})
- // DeleteFront
+ // Consume.
for _, n := range []int{1, len(tc.data)} {
- t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) {
+ t.Run(fmt.Sprintf("Consume%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
- pkt.Data().DeleteFront(n)
+ v, ok := pkt.Data().Consume(n)
+ if !ok {
+ t.Fatalf("Consume failed")
+ }
+ if want := []byte(tc.data)[:n]; !bytes.Equal(v, want) {
+ t.Fatalf("pkt.Data().Consume(n) = 0x%x, want 0x%x", v, want)
+ }
checkData(t, pkt, []byte(tc.data)[n:])
})
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 113baaaae..31b3a554d 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -318,8 +318,7 @@ type PrimaryEndpointBehavior int
const (
// CanBePrimaryEndpoint indicates the endpoint can be used as a primary
- // endpoint for new connections with no local address. This is the
- // default when calling NIC.AddAddress.
+ // endpoint for new connections with no local address.
CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
// FirstPrimaryEndpoint indicates the endpoint should be the first
@@ -332,6 +331,19 @@ const (
NeverPrimaryEndpoint
)
+func (peb PrimaryEndpointBehavior) String() string {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ return "CanBePrimaryEndpoint"
+ case FirstPrimaryEndpoint:
+ return "FirstPrimaryEndpoint"
+ case NeverPrimaryEndpoint:
+ return "NeverPrimaryEndpoint"
+ default:
+ panic(fmt.Sprintf("unknown primary endpoint behavior: %d", peb))
+ }
+}
+
// AddressConfigType is the method used to add an address.
type AddressConfigType int
@@ -351,6 +363,14 @@ const (
AddressConfigSlaacTemp
)
+// AddressProperties contains additional properties that can be configured when
+// adding an address.
+type AddressProperties struct {
+ PEB PrimaryEndpointBehavior
+ ConfigType AddressConfigType
+ Deprecated bool
+}
+
// AssignableAddressEndpoint is a reference counted address endpoint that may be
// assigned to a NetworkEndpoint.
type AssignableAddressEndpoint interface {
@@ -457,7 +477,7 @@ type AddressableEndpoint interface {
// Returns *tcpip.ErrDuplicateAddress if the address exists.
//
// Acquires and returns the AddressEndpoint for the added address.
- AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error)
+ AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error)
// RemovePermanentAddress removes the passed address if it is a permanent
// address.
@@ -685,9 +705,6 @@ type NetworkProtocol interface {
// than this targeted at this protocol.
MinimumPacketSize() int
- // DefaultPrefixLen returns the protocol's default prefix length.
- DefaultPrefixLen() int
-
// ParseAddresses returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index cb741e540..a05fd7036 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -238,7 +238,7 @@ type Options struct {
// DefaultIPTables is an optional iptables rules constructor that is called
// if IPTables is nil. If both fields are nil, iptables will allow all
// traffic.
- DefaultIPTables func(uint32) *IPTables
+ DefaultIPTables func(seed uint32, clock tcpip.Clock) *IPTables
// SecureRNG is a cryptographically secure random number generator.
SecureRNG io.Reader
@@ -358,7 +358,7 @@ func New(opts Options) *Stack {
if opts.DefaultIPTables == nil {
opts.DefaultIPTables = DefaultTables
}
- opts.IPTables = opts.DefaultIPTables(seed)
+ opts.IPTables = opts.DefaultIPTables(seed, clock)
}
opts.NUDConfigs.resetInvalidFields()
@@ -375,7 +375,7 @@ func New(opts Options) *Stack {
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
tables: opts.IPTables,
- icmpRateLimiter: NewICMPRateLimiter(),
+ icmpRateLimiter: NewICMPRateLimiter(clock),
seed: seed,
nudConfigs: opts.NUDConfigs,
uniqueIDGenerator: opts.UniqueID,
@@ -916,46 +916,9 @@ type NICStateFlags struct {
Loopback bool
}
-// AddAddress adds a new network-layer address to the specified NIC.
-func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
- return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
-}
-
-// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
-// the address prefix.
-func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error {
- ap := tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: addr,
- }
- return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint)
-}
-
-// AddProtocolAddress adds a new network-layer protocol address to the
-// specified NIC.
-func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error {
- return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint)
-}
-
-// AddAddressWithOptions is the same as AddAddress, but allows you to specify
-// whether the new endpoint can be primary or not.
-func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error {
- netProto, ok := s.networkProtocols[protocol]
- if !ok {
- return &tcpip.ErrUnknownProtocol{}
- }
- return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{
- Protocol: protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: netProto.DefaultPrefixLen(),
- },
- }, peb)
-}
-
-// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows
-// you to specify whether the new endpoint can be primary or not.
-func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
+// AddProtocolAddress adds an address to the specified NIC, possibly with extra
+// properties.
+func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -964,7 +927,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
return &tcpip.ErrUnknownNICID{}
}
- return nic.addAddress(protocolAddress, peb)
+ return nic.addAddress(protocolAddress, properties)
}
// RemoveAddress removes an existing network-layer address from the specified
@@ -1902,12 +1865,6 @@ const (
// ParsePacketBufferTransport parses the provided packet buffer's transport
// header.
func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult {
- // ICMP packets don't have their TransportHeader fields set yet, parse it
- // here. See icmp/protocol.go:protocol.Parse for a full explanation.
- if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
- return ParsedOK
- }
-
pkt.TransportProtocolNumber = protocol
// Parse the transport header if present.
state, ok := s.transportProtocols[protocol]
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 3089c0ef4..f5a35eac4 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -139,18 +139,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen)
+ hdr, ok := pkt.Data().Consume(fakeNetHeaderLen)
if !ok {
return
}
- // DeleteFront invalidates slices. Make a copy before trimming.
- nb := append([]byte(nil), hdr...)
- pkt.Data().DeleteFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
- tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
- tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
+ tcpip.Address(hdr[srcAddrOffset:srcAddrOffset+1]),
+ tcpip.Address(hdr[dstAddrOffset:dstAddrOffset+1]),
fakeNetNumber,
- tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
+ tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]),
// Nothing checks the error.
nil, /* transport error */
pkt,
@@ -158,8 +155,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
return
}
+ transProtoNum := tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset])
+ switch err := f.proto.stack.ParsePacketBufferTransport(transProtoNum, pkt); err {
+ case stack.ParsedOK:
+ case stack.UnknownTransportProtocol, stack.TransportLayerParseError:
+ // The transport layer will handle unknown protocols and transport layer
+ // parsing errors.
+ default:
+ panic(fmt.Sprintf("unexpected error parsing transport header = %d", err))
+ }
+
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(transProtoNum, pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -221,6 +228,8 @@ func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {}
// number of packets sent and received via endpoints of this protocol. The index
// where packets are added is given by the packet's destination address MOD 10.
type fakeNetworkProtocol struct {
+ stack *stack.Stack
+
packetCount [10]int
sendPacketCount [10]int
defaultTTL uint8
@@ -234,10 +243,6 @@ func (*fakeNetworkProtocol) MinimumPacketSize() int {
return fakeNetHeaderLen
}
-func (*fakeNetworkProtocol) DefaultPrefixLen() int {
- return fakeDefaultPrefixLen
-}
-
func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
return f.packetCount[int(intfAddr)%len(f.packetCount)]
}
@@ -306,8 +311,8 @@ func (f *fakeNetworkEndpoint) SetForwarding(v bool) {
f.mu.forwarding = v
}
-func fakeNetFactory(*stack.Stack) stack.NetworkProtocol {
- return &fakeNetworkProtocol{}
+func fakeNetFactory(s *stack.Stack) stack.NetworkProtocol {
+ return &fakeNetworkProtocol{stack: s}
}
// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify
@@ -349,12 +354,26 @@ func TestNetworkReceive(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr2, err)
}
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
@@ -517,8 +536,15 @@ func TestNetworkSend(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Make sure that the link-layer endpoint received the outbound packet.
@@ -538,12 +564,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x03",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
}
ep2 := channel.New(10, defaultMTU, "")
@@ -551,12 +591,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr4 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x04",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
}
// Set a route table that sends all packets with odd destination
@@ -812,8 +866,15 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
}
ep2 := channel.New(1, defaultMTU, "")
@@ -821,8 +882,15 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr2,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
}
// Set a route table that sends all packets with odd destination
@@ -978,12 +1046,26 @@ func TestRoutes(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x03",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err)
}
ep2 := channel.New(10, defaultMTU, "")
@@ -991,12 +1073,26 @@ func TestRoutes(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x02",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr4 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x04",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err)
}
// Set a route table that sends all packets with odd destination
@@ -1058,8 +1154,15 @@ func TestAddressRemoval(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -1108,8 +1211,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -1242,8 +1352,15 @@ func TestEndpointExpiration(t *testing.T) {
// 2. Add Address, everything should work.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1270,8 +1387,8 @@ func TestEndpointExpiration(t *testing.T) {
// 4. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1310,8 +1427,8 @@ func TestEndpointExpiration(t *testing.T) {
// 7. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
@@ -1453,8 +1570,15 @@ func TestExternalSendWithHandleLocal(t *testing.T) {
if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}})
@@ -1510,8 +1634,15 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
@@ -1633,8 +1764,8 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
}
protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}}
- if err := s.AddProtocolAddress(1, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err)
+ if err := s.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", protoAddr, err)
}
r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
@@ -1678,13 +1809,13 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
t.Fatalf("CreateNIC failed: %s", err)
}
nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr}
- if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err)
+ if err := s.AddProtocolAddress(1, nic1ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", nic1ProtoAddr, err)
}
nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr}
- if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
- t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err)
+ if err := s.AddProtocolAddress(2, nic2ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(2, %+v, {}) failed: %s", nic2ProtoAddr, err)
}
// Set the initial route table.
@@ -1726,7 +1857,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
// 2. Case: Having an explicit route for broadcast will select that one.
rt = append(
[]tcpip.Route{
- {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
+ {Destination: header.IPv4Broadcast.WithPrefix().Subnet(), NIC: 1},
},
rt...,
)
@@ -1808,8 +1939,15 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want)
}
- if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil {
- t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: anyAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
@@ -1886,22 +2024,27 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Add an address and in case of a primary one include a
// prefixLen.
address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen))
+ properties := stack.AddressProperties{PEB: behavior}
if behavior == stack.CanBePrimaryEndpoint {
protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: address,
- PrefixLen: addrLen * 8,
- },
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: address.WithPrefix(),
}
- if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
}
// Remember the address/prefix.
primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{}
} else {
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err)
}
}
}
@@ -1996,8 +2139,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
PrefixLen: tc.prefixLen,
},
}
- if err := s.AddProtocolAddress(1, protocolAddress); err != nil {
- t.Fatal("AddProtocolAddress failed:", err)
+ if err := s.AddProtocolAddress(1, protocolAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", protocolAddress, err)
}
// Check that we get the right initial address and prefix length.
@@ -2047,33 +2190,6 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
}
}
-func TestAddAddress(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- var addrGen addressGenerator
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2)
- for _, addrLen := range []int{4, 16} {
- address := addrGen.next(addrLen)
- if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil {
- t.Fatalf("AddAddress(address=%s) failed: %s", address, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
- })
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
func TestAddProtocolAddress(t *testing.T) {
const nicID = 1
s := stack.New(stack.Options{
@@ -2084,96 +2200,43 @@ func TestAddProtocolAddress(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- var addrGen addressGenerator
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange))
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
- t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err)
- }
- expectedAddresses = append(expectedAddresses, protocolAddress)
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
addrLenRange := []int{4, 16}
behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange))
+ configTypeRange := []stack.AddressConfigType{stack.AddressConfigStatic, stack.AddressConfigSlaac, stack.AddressConfigSlaacTemp}
+ deprecatedRange := []bool{false, true}
+ wantAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)*len(configTypeRange)*len(deprecatedRange))
var addrGen addressGenerator
for _, addrLen := range addrLenRange {
for _, behavior := range behaviorRange {
- address := addrGen.next(addrLen)
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
- t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err)
- }
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
- })
- }
- }
-
- gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
-}
-
-func TestAddProtocolAddressWithOptions(t *testing.T) {
- const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- addrLenRange := []int{4, 16}
- prefixLenRange := []int{8, 13, 20, 32}
- behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint}
- expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange))
- var addrGen addressGenerator
- for _, addrLen := range addrLenRange {
- for _, prefixLen := range prefixLenRange {
- for _, behavior := range behaviorRange {
- protocolAddress := tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addrGen.next(addrLen),
- PrefixLen: prefixLen,
- },
- }
- if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err)
+ for _, configType := range configTypeRange {
+ for _, deprecated := range deprecatedRange {
+ address := addrGen.next(addrLen)
+ properties := stack.AddressProperties{
+ PEB: behavior,
+ ConfigType: configType,
+ Deprecated: deprecated,
+ }
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v) failed: %s", nicID, protocolAddr, properties, err)
+ }
+ wantAddresses = append(wantAddresses, tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen},
+ })
}
- expectedAddresses = append(expectedAddresses, protocolAddress)
}
}
}
gotAddresses := s.AllAddresses()[nicID]
- verifyAddresses(t, expectedAddresses, gotAddresses)
+ verifyAddresses(t, wantAddresses, gotAddresses)
}
func TestCreateNICWithOptions(t *testing.T) {
@@ -2290,8 +2353,15 @@ func TestNICStats(t *testing.T) {
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed: ", err)
}
- if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: nic.addr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicid, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicid, protocolAddr, err)
}
{
@@ -2735,8 +2805,16 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// be returned by a call to GetMainNICAddress;
// else, it should.
const address1 = tcpip.Address("\x01")
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
+ properties := stack.AddressProperties{PEB: pi}
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr, properties, err)
}
addr, err := s.GetMainNICAddress(nicID, fakeNetNumber)
if err != nil {
@@ -2785,16 +2863,31 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// Add some other address with peb set to
// FirstPrimaryEndpoint.
const address3 = tcpip.Address("\x03")
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err)
-
+ protocolAddr3 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address3,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ properties = stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, protocolAddr3, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr3, properties, err)
}
// Add back the address we removed earlier and
// make sure the new peb was respected.
// (The address should just be promoted now).
- if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil {
- t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address1,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ properties = stack.AddressProperties{PEB: ps}
+ if err := s.AddProtocolAddress(nicID, protocolAddr1, properties); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr1, properties, err)
}
var primaryAddrs []tcpip.Address
for _, pa := range s.NICInfo()[nicID].ProtocolAddresses {
@@ -3096,8 +3189,12 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
}
for _, a := range test.nicAddrs {
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil {
- t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: a.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -3203,8 +3300,12 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: addr1.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
// The NIC should have joined addr1's solicited node multicast address.
@@ -3359,8 +3460,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
PrefixLen: 128,
},
}
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
// Address should be in the list of all addresses.
@@ -3687,8 +3788,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID1, ep); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
}
s.SetRouteTable(test.routes)
@@ -3750,8 +3851,8 @@ func TestResolveWith(t *testing.T) {
PrefixLen: 24,
},
}
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err)
+ if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err)
}
s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
@@ -3792,8 +3893,15 @@ func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: localAddr,
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -3881,8 +3989,8 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
PrefixLen: 8,
},
}
- if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddress, err)
}
// Check that we get the right initial address and prefix length.
@@ -3990,44 +4098,44 @@ func TestFindRouteWithForwarding(t *testing.T) {
)
type netCfg struct {
- proto tcpip.NetworkProtocolNumber
- factory stack.NetworkProtocolFactory
- nic1Addr tcpip.Address
- nic2Addr tcpip.Address
- remoteAddr tcpip.Address
+ proto tcpip.NetworkProtocolNumber
+ factory stack.NetworkProtocolFactory
+ nic1AddrWithPrefix tcpip.AddressWithPrefix
+ nic2AddrWithPrefix tcpip.AddressWithPrefix
+ remoteAddr tcpip.Address
}
fakeNetCfg := netCfg{
- proto: fakeNetNumber,
- factory: fakeNetFactory,
- nic1Addr: nic1Addr,
- nic2Addr: nic2Addr,
- remoteAddr: remoteAddr,
+ proto: fakeNetNumber,
+ factory: fakeNetFactory,
+ nic1AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic1Addr, PrefixLen: fakeDefaultPrefixLen},
+ nic2AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic2Addr, PrefixLen: fakeDefaultPrefixLen},
+ remoteAddr: remoteAddr,
}
globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16())
globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16())
ipv6LinkLocalNIC1WithGlobalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: llAddr1,
- nic2Addr: globalIPv6Addr2,
- remoteAddr: globalIPv6Addr1,
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: llAddr1.WithPrefix(),
+ nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
+ remoteAddr: globalIPv6Addr1,
}
ipv6GlobalNIC1WithLinkLocalRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: globalIPv6Addr1,
- nic2Addr: llAddr1,
- remoteAddr: llAddr2,
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
+ nic2AddrWithPrefix: llAddr1.WithPrefix(),
+ remoteAddr: llAddr2,
}
ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{
- proto: ipv6.ProtocolNumber,
- factory: ipv6.NewProtocol,
- nic1Addr: globalIPv6Addr1,
- nic2Addr: globalIPv6Addr2,
- remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(),
+ nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(),
+ remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
}
tests := []struct {
@@ -4036,8 +4144,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg netCfg
forwardingEnabled bool
- addrNIC tcpip.NICID
- localAddr tcpip.Address
+ addrNIC tcpip.NICID
+ localAddrWithPrefix tcpip.AddressWithPrefix
findRouteErr tcpip.Error
dependentOnForwarding bool
@@ -4047,7 +4155,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4056,7 +4164,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4065,7 +4173,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4074,7 +4182,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID1,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: true,
},
@@ -4083,7 +4191,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4092,7 +4200,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4101,7 +4209,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4110,7 +4218,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: true,
addrNIC: nicID2,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4118,7 +4226,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and localAddr on same NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4126,7 +4234,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and localAddr on same NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic2Addr,
+ localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4134,7 +4242,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and localAddr on different NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: false,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4142,7 +4250,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and localAddr on different NIC as route",
netCfg: fakeNetCfg,
forwardingEnabled: true,
- localAddr: fakeNetCfg.nic1Addr,
+ localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: true,
},
@@ -4166,7 +4274,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and link-local local addr with route on different NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: false,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4174,7 +4282,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and link-local local addr with route on same NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
@@ -4182,7 +4290,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with route on same NIC",
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
- localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4190,7 +4298,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and link-local local addr with route on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4198,7 +4306,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and link-local local addr with route on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4206,7 +4314,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4214,7 +4322,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4222,7 +4330,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local multicast remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4230,7 +4338,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local multicast remote on different NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix,
findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
@@ -4238,7 +4346,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding disabled and global local addr with link-local multicast remote on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4246,7 +4354,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
name: "forwarding enabled and global local addr with link-local multicast remote on same NIC",
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
- localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix,
findRouteErr: nil,
dependentOnForwarding: false,
},
@@ -4268,12 +4376,20 @@ func TestFindRouteWithForwarding(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err)
}
- if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err)
+ protocolAddr1 := tcpip.ProtocolAddress{
+ Protocol: test.netCfg.proto,
+ AddressWithPrefix: test.netCfg.nic1AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err)
}
- if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err)
+ protocolAddr2 := tcpip.ProtocolAddress{
+ Protocol: test.netCfg.proto,
+ AddressWithPrefix: test.netCfg.nic2AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err)
}
if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil {
@@ -4282,20 +4398,20 @@ func TestFindRouteWithForwarding(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
- r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ r, err := s.FindRoute(test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
if err == nil {
defer r.Release()
}
if diff := cmp.Diff(test.findRouteErr, err); diff != "" {
- t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff)
+ t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, diff)
}
if test.findRouteErr != nil {
return
}
- if r.LocalAddress() != test.localAddr {
- t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddr)
+ if r.LocalAddress() != test.localAddrWithPrefix.Address {
+ t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddrWithPrefix.Address)
}
if r.RemoteAddress() != test.netCfg.remoteAddr {
t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr)
@@ -4318,8 +4434,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
if !ok {
t.Fatal("packet not sent through ep2")
}
- if pkt.Route.LocalAddress != test.localAddr {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr)
+ if pkt.Route.LocalAddress != test.localAddrWithPrefix.Address {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddrWithPrefix.Address)
}
if pkt.Route.RemoteAddress != test.netCfg.remoteAddr {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr)
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
index dc7289441..a941091b0 100644
--- a/pkg/tcpip/stack/tcp.go
+++ b/pkg/tcpip/stack/tcp.go
@@ -289,6 +289,12 @@ type TCPSenderState struct {
// RACKState holds the state related to RACK loss detection algorithm.
RACKState TCPRACKState
+
+ // RetransmitTS records the timestamp used to detect spurious recovery.
+ RetransmitTS uint32
+
+ // SpuriousRecovery indicates if the sender entered recovery spuriously.
+ SpuriousRecovery bool
}
// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 824cf6526..3474c292a 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -32,11 +32,13 @@ type protocolIDs struct {
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
- // mu protects all fields of the transportEndpoints.
- mu sync.RWMutex
+ mu sync.RWMutex
+ // +checklocks:mu
endpoints map[TransportEndpointID]*endpointsByNIC
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
+ //
+ // +checklocks:mu
rawEndpoints []RawTransportEndpoint
}
@@ -69,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
// descending order of match quality. If a call to yield returns false,
// iterEndpointsLocked stops iteration and returns immediately.
//
-// Preconditions: eps.mu must be locked.
+// +checklocksread:eps.mu
func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
@@ -110,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield
// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
// descending order of match quality.
//
-// Preconditions: eps.mu must be locked.
+// +checklocksread:eps.mu
func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
var matchedEPs []*endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@@ -122,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []
// findEndpointLocked returns the endpoint that most closely matches the given id.
//
-// Preconditions: eps.mu must be locked.
+// +checklocksread:eps.mu
func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
var matchedEP *endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@@ -133,10 +135,12 @@ func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpo
}
type endpointsByNIC struct {
- mu sync.RWMutex
- endpoints map[tcpip.NICID]*multiPortEndpoint
// seed is a random secret for a jenkins hash.
seed uint32
+
+ mu sync.RWMutex
+ // +checklocks:mu
+ endpoints map[tcpip.NICID]*multiPortEndpoint
}
func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
@@ -171,7 +175,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
return true
}
// multiPortEndpoints are guaranteed to have at least one element.
- transEP := selectEndpoint(id, mpep, epsByNIC.seed)
+ transEP := mpep.selectEndpoint(id, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
@@ -200,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, tran
// broadcast like we are doing with handlePacket above?
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt)
+ mpep.selectEndpoint(id, epsByNIC.seed).HandleError(transErr, pkt)
}
// registerEndpoint returns true if it succeeds. It fails and returns
@@ -333,15 +337,18 @@ func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber
//
// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex `state:"nosave"`
demux *transportDemuxer
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
+ flags ports.FlagCounter
+
+ mu sync.RWMutex `state:"nosave"`
// endpoints stores the transport endpoints in the order in which they
// were bound. This is required for UDP SO_REUSEADDR.
+ //
+ // +checklocks:mu
endpoints []TransportEndpoint
- flags ports.FlagCounter
}
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
@@ -362,13 +369,16 @@ func reciprocalScale(val, n uint32) uint32 {
// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
-func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
- if len(mpep.endpoints) == 1 {
- return mpep.endpoints[0]
+func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ if len(ep.endpoints) == 1 {
+ return ep.endpoints[0]
}
- if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent {
- return mpep.endpoints[len(mpep.endpoints)-1]
+ if ep.flags.SharedFlags().ToFlags().Effective().MostRecent {
+ return ep.endpoints[len(ep.endpoints)-1]
}
payload := []byte{
@@ -384,8 +394,8 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(mpep.endpoints)))
- return mpep.endpoints[idx]
+ idx := reciprocalScale(hash, uint32(len(ep.endpoints)))
+ return ep.endpoints[idx]
}
func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
@@ -657,7 +667,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
}
}
- ep := selectEndpoint(id, mpep, epsByNIC.seed)
+ ep := mpep.selectEndpoint(id, epsByNIC.seed)
epsByNIC.mu.RUnlock()
return ep
}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 45b09110d..cd3a8c25a 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -35,7 +35,7 @@ import (
const (
testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ testDstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
testSrcAddrV4 = "\x0a\x00\x00\x01"
testDstAddrV4 = "\x0a\x00\x00\x02"
@@ -64,12 +64,20 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI
}
linkEps[linkEpID] = channelEp
- if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
- t.Fatalf("AddAddress IPv4 failed: %s", err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(testDstAddrV4).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err)
}
- if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
- t.Fatalf("AddAddress IPv6 failed: %s", err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: testDstAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err)
}
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 839178809..51870d03f 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -331,8 +331,11 @@ func (*fakeTransportProtocol) Wait() {}
// Parse implements TransportProtocol.Parse.
func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
- _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen)
- return ok
+ if _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen); ok {
+ pkt.TransportProtocolNumber = fakeTransNumber
+ return true
+ }
+ return false
}
func fakeTransFactory(s *stack.Stack) stack.TransportProtocol {
@@ -357,8 +360,15 @@ func TestTransportReceive(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Create endpoint and connect to remote address.
@@ -428,8 +438,15 @@ func TestTransportControlReceive(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
// Create endpoint and connect to remote address.
@@ -497,8 +514,15 @@ func TestTransportSend(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: fakeDefaultPrefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err)
}
{
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 55683b4fb..460a6afaf 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -19,7 +19,7 @@
// The starting point is the creation and configuration of a stack. A stack can
// be created by calling the New() function of the tcpip/stack/stack package;
// configuring a stack involves creating NICs (via calls to Stack.CreateNIC()),
-// adding network addresses (via calls to Stack.AddAddress()), and
+// adding network addresses (via calls to Stack.AddProtocolAddress()), and
// setting a route table (via a call to Stack.SetRouteTable()).
//
// Once a stack is configured, endpoints can be created by calling
@@ -423,9 +423,9 @@ type ControlMessages struct {
// HasTimestamp indicates whether Timestamp is valid/set.
HasTimestamp bool
- // Timestamp is the time (in ns) that the last packet used to create
- // the read data was received.
- Timestamp int64
+ // Timestamp is the time that the last packet used to create the read data
+ // was received.
+ Timestamp time.Time `state:".(int64)"`
// HasInq indicates whether Inq is valid/set.
HasInq bool
@@ -451,6 +451,12 @@ type ControlMessages struct {
// PacketInfo holds interface and address data on an incoming packet.
PacketInfo IPPacketInfo
+ // HasIPv6PacketInfo indicates whether IPv6PacketInfo is set.
+ HasIPv6PacketInfo bool
+
+ // IPv6PacketInfo holds interface and address data on an incoming packet.
+ IPv6PacketInfo IPv6PacketInfo
+
// HasOriginalDestinationAddress indicates whether OriginalDstAddress is
// set.
HasOriginalDstAddress bool
@@ -465,10 +471,10 @@ type ControlMessages struct {
// PacketOwner is used to get UID and GID of the packet.
type PacketOwner interface {
- // UID returns KUID of the packet.
+ // KUID returns KUID of the packet.
KUID() uint32
- // GID returns KGID of the packet.
+ // KGID returns KGID of the packet.
KGID() uint32
}
@@ -1164,6 +1170,14 @@ type IPPacketInfo struct {
DestinationAddr Address
}
+// IPv6PacketInfo is the message structure for IPV6_PKTINFO.
+//
+// +stateify savable
+type IPv6PacketInfo struct {
+ Addr Address
+ NIC NICID
+}
+
// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to
// get/set the default, min and max send buffer sizes.
type SendBufferSizeOption struct {
@@ -1231,11 +1245,11 @@ type Route struct {
// String implements the fmt.Stringer interface.
func (r Route) String() string {
var out strings.Builder
- fmt.Fprintf(&out, "%s", r.Destination)
+ _, _ = fmt.Fprintf(&out, "%s", r.Destination)
if len(r.Gateway) > 0 {
- fmt.Fprintf(&out, " via %s", r.Gateway)
+ _, _ = fmt.Fprintf(&out, " via %s", r.Gateway)
}
- fmt.Fprintf(&out, " nic %d", r.NIC)
+ _, _ = fmt.Fprintf(&out, " nic %d", r.NIC)
return out.String()
}
@@ -1255,6 +1269,8 @@ type TransportProtocolNumber uint32
type NetworkProtocolNumber uint32
// A StatCounter keeps track of a statistic.
+//
+// +stateify savable
type StatCounter struct {
count atomicbitops.AlignedAtomicUint64
}
@@ -1270,7 +1286,7 @@ func (s *StatCounter) Decrement() {
}
// Value returns the current value of the counter.
-func (s *StatCounter) Value(name ...string) uint64 {
+func (s *StatCounter) Value(...string) uint64 {
return s.count.Load()
}
@@ -1849,6 +1865,10 @@ type TCPStats struct {
// SegmentsAckedWithDSACK is the number of segments acknowledged with
// DSACK.
SegmentsAckedWithDSACK *StatCounter
+
+ // SpuriousRecovery is the number of times the connection entered loss
+ // recovery spuriously.
+ SpuriousRecovery *StatCounter
}
// UDPStats collects UDP-specific stats.
@@ -1981,6 +2001,8 @@ type Stats struct {
}
// ReceiveErrors collects packet receive errors within transport endpoint.
+//
+// +stateify savable
type ReceiveErrors struct {
// ReceiveBufferOverflow is the number of received packets dropped
// due to the receive buffer being full.
@@ -1998,8 +2020,10 @@ type ReceiveErrors struct {
ChecksumErrors StatCounter
}
-// SendErrors collects packet send errors within the transport layer for
-// an endpoint.
+// SendErrors collects packet send errors within the transport layer for an
+// endpoint.
+//
+// +stateify savable
type SendErrors struct {
// SendToNetworkFailed is the number of packets failed to be written to
// the network endpoint.
@@ -2010,6 +2034,8 @@ type SendErrors struct {
}
// ReadErrors collects segment read errors from an endpoint read call.
+//
+// +stateify savable
type ReadErrors struct {
// ReadClosed is the number of received packet drops because the endpoint
// was shutdown for read.
@@ -2025,6 +2051,8 @@ type ReadErrors struct {
}
// WriteErrors collects packet write errors from an endpoint write call.
+//
+// +stateify savable
type WriteErrors struct {
// WriteClosed is the number of packet drops because the endpoint
// was shutdown for write.
@@ -2040,6 +2068,8 @@ type WriteErrors struct {
}
// TransportEndpointStats collects statistics about the endpoint.
+//
+// +stateify savable
type TransportEndpointStats struct {
// PacketsReceived is the number of successful packet receives.
PacketsReceived StatCounter
diff --git a/pkg/tcpip/tcpip_state.go b/pkg/tcpip/tcpip_state.go
new file mode 100644
index 000000000..1953e24a1
--- /dev/null
+++ b/pkg/tcpip/tcpip_state.go
@@ -0,0 +1,27 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "time"
+)
+
+func (c *ControlMessages) saveTimestamp() int64 {
+ return c.Timestamp.UnixNano()
+}
+
+func (c *ControlMessages) loadTimestamp(nsec int64) {
+ c.Timestamp = time.Unix(0, nsec)
+}
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 181ef799e..99f4d4d0e 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -34,12 +34,16 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/tests/utils",
"//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
@@ -139,3 +143,25 @@ go_test(
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
+
+go_test(
+ name = "istio_test",
+ size = "small",
+ srcs = ["istio_test.go"],
+ deps = [
+ "//pkg/context",
+ "//pkg/rand",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/adapters/gonet",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/pipe",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
+ "//pkg/tcpip/transport/tcp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 92fa6257d..6e1d4720d 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -473,11 +473,19 @@ func TestMulticastForwarding(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr,
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr,
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -612,8 +620,8 @@ func TestPerInterfaceForwarding(t *testing.T) {
addr: utils.RouterNIC2IPv6Addr,
},
} {
- if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err)
+ if err := s.AddProtocolAddress(add.nicID, add.addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", add.nicID, add.addr, err)
}
}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index f9ab7d0af..957a779bf 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -15,19 +15,24 @@
package iptables_test
import (
+ "bytes"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
type inputIfNameMatcher struct {
@@ -49,10 +54,10 @@ const (
nicName = "nic1"
anotherNicName = "nic2"
linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
- srcAddrV4 = "\x0a\x00\x00\x01"
- dstAddrV4 = "\x0a\x00\x00\x02"
- srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ srcAddrV4 = tcpip.Address("\x0a\x00\x00\x01")
+ dstAddrV4 = tcpip.Address("\x0a\x00\x00\x02")
+ srcAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ dstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
payloadSize = 20
)
@@ -66,8 +71,12 @@ func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) {
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: dstAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, e
}
@@ -82,8 +91,12 @@ func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) {
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: dstAddrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
return s, e
}
@@ -601,11 +614,19 @@ func TestIPTableWritePackets(t *testing.T) {
if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: srcAddrV6.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
+ }
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: srcAddrV4.WithPrefix(),
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -856,11 +877,19 @@ func TestForwardingHook(t *testing.T) {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(),
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -1037,22 +1066,22 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err)
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err)
}
- if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err)
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err)
}
e2 := channel.New(1, header.IPv6MinimumMTU, "")
if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
}
- if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err)
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err)
}
- if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err)
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -1132,3 +1161,621 @@ func TestInputHookWithLocalForwarding(t *testing.T) {
})
}
}
+
+func TestNAT(t *testing.T) {
+ const listenPort uint16 = 8080
+
+ type endpointAndAddresses struct {
+ serverEP tcpip.Endpoint
+ serverAddr tcpip.FullAddress
+ serverReadableCH chan struct{}
+ serverConnectAddr tcpip.Address
+
+ clientEP tcpip.Endpoint
+ clientAddr tcpip.Address
+ clientReadableCH chan struct{}
+ clientConnectAddr tcpip.FullAddress
+ }
+
+ newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ t.Cleanup(func() {
+ wq.EventUnregister(&we)
+ })
+
+ ep, err := s.NewEndpoint(transProto, netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
+ }
+ t.Cleanup(ep.Close)
+
+ return ep, ch
+ }
+
+ setupNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, hook stack.Hook, filter stack.IPHeaderFilter, target stack.Target) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+ table := ipt.GetTable(stack.NATID, ipv6)
+ ruleIdx := table.BuiltinChains[hook]
+ table.Rules[ruleIdx].Filter = filter
+ table.Rules[ruleIdx].Target = target
+ // Make sure the packet is not dropped by the next rule.
+ table.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
+ }
+
+ setupDNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) {
+ t.Helper()
+
+ setupNAT(
+ t,
+ s,
+ netProto,
+ stack.Prerouting,
+ stack.IPHeaderFilter{
+ Protocol: transProto,
+ CheckProtocol: true,
+ InputInterface: utils.RouterNIC2Name,
+ },
+ target)
+ }
+
+ setupSNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) {
+ t.Helper()
+
+ setupNAT(
+ t,
+ s,
+ netProto,
+ stack.Postrouting,
+ stack.IPHeaderFilter{
+ Protocol: transProto,
+ CheckProtocol: true,
+ OutputInterface: utils.RouterNIC1Name,
+ },
+ target)
+ }
+
+ type natType struct {
+ name string
+ setupNAT func(_ *testing.T, _ *stack.Stack, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address)
+ }
+
+ snatTypes := []natType{
+ {
+ name: "SNAT",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, _ tcpip.Address) {
+ t.Helper()
+
+ setupSNAT(t, s, netProto, transProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr})
+ },
+ },
+ {
+ name: "Masquerade",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) {
+ t.Helper()
+
+ setupSNAT(t, s, netProto, transProto, &stack.MasqueradeTarget{NetworkProtocol: netProto})
+ },
+ },
+ }
+ dnatTypes := []natType{
+ {
+ name: "Redirect",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address) {
+ t.Helper()
+
+ setupDNAT(t, s, netProto, transProto, &stack.RedirectTarget{NetworkProtocol: netProto, Port: listenPort})
+ },
+ },
+ {
+ name: "DNAT",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, dnatAddr tcpip.Address) {
+ t.Helper()
+
+ setupDNAT(t, s, netProto, transProto, &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort})
+ },
+ },
+ }
+
+ setupTwiceNAT := func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, dnatAddr tcpip.Address, snatTarget stack.Target) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ // Prerouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transProto,
+ CheckProtocol: true,
+ InputInterface: utils.RouterNIC2Name,
+ },
+ Target: &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: listenPort},
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Input
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Forward
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Output
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Postrouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transProto,
+ CheckProtocol: true,
+ OutputInterface: utils.RouterNIC1Name,
+ },
+ Target: snatTarget,
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 2,
+ stack.Forward: 3,
+ stack.Output: 4,
+ stack.Postrouting: 5,
+ },
+ }
+
+ if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
+ }
+ twiceNATTypes := []natType{
+ {
+ name: "DNAT-Masquerade",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) {
+ t.Helper()
+
+ setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.MasqueradeTarget{NetworkProtocol: netProto})
+ },
+ },
+ {
+ name: "DNAT-SNAT",
+ setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address) {
+ t.Helper()
+
+ setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr})
+ },
+ },
+ }
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ // Setups up the stacks in such a way that:
+ //
+ // - Host2 is the client for all tests.
+ // - When performing SNAT only:
+ // + Host1 is the server.
+ // + NAT will transform client-originating packets' source addresses to
+ // the router's NIC1's address before reaching Host1.
+ // - When performing DNAT only:
+ // + Router is the server.
+ // + Client will send packets directed to Host1.
+ // + NAT will transform client-originating packets' destination addresses
+ // to the router's NIC2's address.
+ // - When performing Twice-NAT:
+ // + Host1 is the server.
+ // + Client will send packets directed to router's NIC2.
+ // + NAT will transform client originating packets' destination addresses
+ // to Host1's address.
+ // + NAT will transform client-originating packets' source addresses to
+ // the router's NIC1's address before reaching Host1.
+ epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses
+ natTypes []natType
+ }{
+ {
+ name: "IPv4 SNAT",
+ netProto: ipv4.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ listenerStack := host1Stack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address
+ clientConnectPort := serverAddr.Port
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: snatTypes,
+ },
+ {
+ name: "IPv4 DNAT",
+ netProto: ipv4.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ // If we are performing DNAT, then the packet will be redirected
+ // to the router.
+ listenerStack := routerStack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.Host2IPv4Addr.AddressWithPrefix.Address
+ // DNAT will update the destination port to what the server is
+ // bound to.
+ clientConnectPort := serverAddr.Port + 1
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: dnatTypes,
+ },
+ {
+ name: "IPv4 Twice-NAT",
+ netProto: ipv4.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ listenerStack := host1Stack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address
+ clientConnectPort := serverAddr.Port
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: twiceNATTypes,
+ },
+ {
+ name: "IPv6 SNAT",
+ netProto: ipv6.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ listenerStack := host1Stack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address
+ clientConnectPort := serverAddr.Port
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: snatTypes,
+ },
+ {
+ name: "IPv6 DNAT",
+ netProto: ipv6.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ // If we are performing DNAT, then the packet will be redirected
+ // to the router.
+ listenerStack := routerStack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.Host2IPv6Addr.AddressWithPrefix.Address
+ // DNAT will update the destination port to what the server is
+ // bound to.
+ clientConnectPort := serverAddr.Port + 1
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: dnatTypes,
+ },
+ {
+ name: "IPv6 Twice-NAT",
+ netProto: ipv6.ProtocolNumber,
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses {
+ t.Helper()
+
+ listenerStack := host1Stack
+ serverAddr := tcpip.FullAddress{
+ Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
+ Port: listenPort,
+ }
+ serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address
+ clientConnectPort := serverAddr.Port
+ ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: serverAddr,
+ serverReadableCH: ep1WECH,
+ serverConnectAddr: serverConnectAddr,
+
+ clientEP: ep2,
+ clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ clientConnectAddr: tcpip.FullAddress{
+ Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address,
+ Port: clientConnectPort,
+ },
+ }
+ },
+ natTypes: twiceNATTypes,
+ },
+ }
+
+ subTests := []struct {
+ name string
+ proto tcpip.TransportProtocolNumber
+ expectedConnectErr tcpip.Error
+ setupServer func(t *testing.T, ep tcpip.Endpoint)
+ setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
+ needRemoteAddr bool
+ }{
+ {
+ name: "UDP",
+ proto: udp.ProtocolNumber,
+ expectedConnectErr: nil,
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ if err := ep.Connect(clientAddr); err != nil {
+ t.Fatalf("ep.Connect(%#v): %s", clientAddr, err)
+ }
+ return nil, nil
+ },
+ needRemoteAddr: true,
+ },
+ {
+ name: "TCP",
+ proto: tcp.ProtocolNumber,
+ expectedConnectErr: &tcpip.ErrConnectStarted{},
+ setupServer: func(t *testing.T, ep tcpip.Endpoint) {
+ t.Helper()
+
+ if err := ep.Listen(1); err != nil {
+ t.Fatalf("ep.Listen(1): %s", err)
+ }
+ },
+ setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+
+ var addr tcpip.FullAddress
+ for {
+ newEP, wq, err := ep.Accept(&addr)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Accept(_): %s", err)
+ }
+ if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath(
+ "NIC",
+ )); diff != "" {
+ t.Errorf("accepted address mismatch (-want +got):\n%s", diff)
+ }
+
+ we, newCH := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.ReadableEvents)
+ return newEP, newCH
+ }
+ },
+ needRemoteAddr: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ for _, natType := range test.natTypes {
+ t.Run(natType.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+ utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack)
+
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
+ natType.setupNAT(t, routerStack, test.netProto, subTest.proto, epsAndAddrs.serverConnectAddr, epsAndAddrs.serverAddr.Addr)
+
+ if err := epsAndAddrs.serverEP.Bind(epsAndAddrs.serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", epsAndAddrs.serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
+
+ if subTest.setupServer != nil {
+ subTest.setupServer(t, epsAndAddrs.serverEP)
+ }
+ {
+ err := epsAndAddrs.clientEP.Connect(epsAndAddrs.clientConnectAddr)
+ if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
+ t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", epsAndAddrs.clientConnectAddr, diff)
+ }
+ }
+ serverConnectAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverConnectAddr}
+ if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
+ } else {
+ serverConnectAddr.Port = addr.Port
+ }
+
+ serverEP := epsAndAddrs.serverEP
+ serverCH := epsAndAddrs.serverReadableCH
+ if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, serverConnectAddr); ep != nil {
+ defer ep.Close()
+ serverEP = ep
+ serverCH = ch
+ }
+
+ write := func(ep tcpip.Endpoint, data []byte) {
+ t.Helper()
+
+ var r bytes.Reader
+ r.Reset(data)
+ var wOpts tcpip.WriteOptions
+ n, err := ep.Write(&r, wOpts)
+ if err != nil {
+ t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
+ }
+ }
+
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) {
+ t.Helper()
+
+ var buf bytes.Buffer
+ var res tcpip.ReadResult
+ for {
+ var err tcpip.Error
+ opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
+ res, err = ep.Read(&buf, opts)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ <-ch
+ continue
+ }
+ if err != nil {
+ t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ }
+ break
+ }
+
+ readResult := tcpip.ReadResult{
+ Count: len(data),
+ Total: len(data),
+ }
+ if subTest.needRemoteAddr {
+ readResult.RemoteAddr = expectedFrom
+ }
+ if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath(
+ "ControlMessages",
+ "RemoteAddr.NIC",
+ )); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(buf.Bytes(), data); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ {
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data)
+ read(serverCH, serverEP, data, serverConnectAddr)
+ }
+
+ {
+ data := []byte{5, 6, 7, 8, 9, 10, 11, 12}
+ write(serverEP, data)
+ read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.clientConnectAddr)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/istio_test.go b/pkg/tcpip/tests/integration/istio_test.go
new file mode 100644
index 000000000..95d994ef8
--- /dev/null
+++ b/pkg/tcpip/tests/integration/istio_test.go
@@ -0,0 +1,365 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package istio_test
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "strconv"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+)
+
+// testContext encapsulates the state required to run tests that simulate
+// an istio like environment.
+//
+// A diagram depicting the setup is shown below.
+// +-----------------------------------------------------------------------+
+// | +-------------------------------------------------+ |
+// | + ----------+ | + -----------------+ PROXY +----------+ | |
+// | | clientEP | | | serverListeningEP|--accepted-> | serverEP |-+ | |
+// | + ----------+ | + -----------------+ +----------+ | | |
+// | | -------|-------------+ +----------+ | | |
+// | | | | | proxyEP |-+ | |
+// | +-----redirect | +----------+ | |
+// | + ------------+---|------+---+ |
+// | | |
+// | Local Stack. | |
+// +-------------------------------------------------------|---------------+
+// |
+// +-----------------------------------------------------------------------+
+// | remoteStack | |
+// | +-------------SYN ---------------| |
+// | | | |
+// | +-------------------|--------------------------------|-_---+ |
+// | | + -----------------+ + ----------+ | | |
+// | | | remoteListeningEP|--accepted--->| remoteEP |<++ | |
+// | | + -----------------+ + ----------+ | |
+// | | Remote HTTP Server | |
+// | +----------------------------------------------------------+ |
+// +-----------------------------------------------------------------------+
+//
+type testContext struct {
+ // localServerListener is the listening port for the server which will proxy
+ // all traffic to the remote EP.
+ localServerListener *gonet.TCPListener
+
+	// remoteServerListener is the remote listening endpoint that will receive
+ // connections from server.
+ remoteServerListener *gonet.TCPListener
+
+ // localStack is the stack used to create client/server endpoints and
+ // also the stack on which we install NAT redirect rules.
+ localStack *stack.Stack
+
+ // remoteStack is the stack that represents a *remote* server.
+ remoteStack *stack.Stack
+
+	// defaultResponse is the response served by the HTTP server for all
+	// GET requests.
+	defaultResponse []byte
+
+	// wg waits for the HTTP server and proxy goroutines to exit in cleanup.
+	wg sync.WaitGroup
+}
+
+func (ctx *testContext) cleanup() {
+ ctx.localServerListener.Close()
+ ctx.localStack.Close()
+ ctx.remoteServerListener.Close()
+ ctx.remoteStack.Close()
+ ctx.wg.Wait()
+}
+
+const (
+ localServerPort = 8080
+ remoteServerPort = 9090
+)
+
+var (
+ localIPv4Addr1 = testutil.MustParse4("10.0.0.1")
+ localIPv4Addr2 = testutil.MustParse4("10.0.0.2")
+ loopbackIPv4Addr = testutil.MustParse4("127.0.0.1")
+ remoteIPv4Addr1 = testutil.MustParse4("10.0.0.3")
+)
+
+func newTestContext(t *testing.T) *testContext {
+ t.Helper()
+ localNIC, remoteNIC := pipe.New("" /* linkAddr1 */, "" /* linkAddr2 */)
+
+ localStack := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ HandleLocal: true,
+ })
+
+ remoteStack := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ HandleLocal: true,
+ })
+
+	// Add a loopback NIC, which is needed because the NAT redirect rule
+	// redirects traffic to the loopback address plus a specified port.
+ loopbackNIC := loopback.New()
+ const loopbackNICID = tcpip.NICID(1)
+ if err := localStack.CreateNIC(loopbackNICID, sniffer.New(loopbackNIC)); err != nil {
+ t.Fatalf("localStack.CreateNIC(%d, _): %s", loopbackNICID, err)
+ }
+ loopbackAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: loopbackIPv4Addr.WithPrefix(),
+ }
+ if err := localStack.AddProtocolAddress(loopbackNICID, loopbackAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", loopbackNICID, loopbackAddr, err)
+ }
+
+ // Create linked NICs that connects the local and remote stack.
+ const localNICID = tcpip.NICID(2)
+ const remoteNICID = tcpip.NICID(3)
+ if err := localStack.CreateNIC(localNICID, sniffer.New(localNIC)); err != nil {
+ t.Fatalf("localStack.CreateNIC(%d, _): %s", localNICID, err)
+ }
+ if err := remoteStack.CreateNIC(remoteNICID, sniffer.New(remoteNIC)); err != nil {
+ t.Fatalf("remoteStack.CreateNIC(%d, _): %s", remoteNICID, err)
+ }
+
+ for _, addr := range []tcpip.Address{localIPv4Addr1, localIPv4Addr2} {
+ localProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: addr.WithPrefix(),
+ }
+ if err := localStack.AddProtocolAddress(localNICID, localProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("localStack.AddProtocolAddress(%d, %+v, {}): %s", localNICID, localProtocolAddr, err)
+ }
+ }
+
+ remoteProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: remoteIPv4Addr1.WithPrefix(),
+ }
+ if err := remoteStack.AddProtocolAddress(remoteNICID, remoteProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("remoteStack.AddProtocolAddress(%d, %+v, {}): %s", remoteNICID, remoteProtocolAddr, err)
+ }
+
+ // Setup route table for local and remote stacks.
+ localStack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4LoopbackSubnet,
+ NIC: loopbackNICID,
+ },
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: localNICID,
+ },
+ })
+ remoteStack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: remoteNICID,
+ },
+ })
+
+ const netProto = ipv4.ProtocolNumber
+ localServerAddress := tcpip.FullAddress{
+ Port: localServerPort,
+ }
+
+ localServerListener, err := gonet.ListenTCP(localStack, localServerAddress, netProto)
+ if err != nil {
+ t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", localServerAddress, netProto, err)
+ }
+
+ remoteServerAddress := tcpip.FullAddress{
+ Port: remoteServerPort,
+ }
+ remoteServerListener, err := gonet.ListenTCP(remoteStack, remoteServerAddress, netProto)
+ if err != nil {
+ t.Fatalf("gonet.ListenTCP(_, %+v, %d) = %s", remoteServerAddress, netProto, err)
+ }
+
+ // Initialize a random default response served by the HTTP server.
+ defaultResponse := make([]byte, 512<<10)
+ if _, err := rand.Read(defaultResponse); err != nil {
+ t.Fatalf("rand.Read(buf) failed: %s", err)
+ }
+
+ tc := &testContext{
+ localServerListener: localServerListener,
+ remoteServerListener: remoteServerListener,
+ localStack: localStack,
+ remoteStack: remoteStack,
+ defaultResponse: defaultResponse,
+ }
+
+ tc.startServers(t)
+ return tc
+}
+
+func (ctx *testContext) startServers(t *testing.T) {
+ ctx.wg.Add(1)
+ go func() {
+ defer ctx.wg.Done()
+ ctx.startHTTPServer()
+ }()
+ ctx.wg.Add(1)
+ go func() {
+ defer ctx.wg.Done()
+ ctx.startTCPProxyServer(t)
+ }()
+}
+
+func (ctx *testContext) startTCPProxyServer(t *testing.T) {
+	t.Helper()
+	for {
+		conn, err := ctx.localServerListener.Accept()
+		if err != nil {
+			t.Logf("terminating local proxy server: %s", err)
+			return
+		}
+		// Start a goroutine to handle this inbound connection.
+		go func() {
+			remoteServerAddr := tcpip.FullAddress{
+				Addr: remoteIPv4Addr1,
+				Port: remoteServerPort,
+			}
+			localServerAddr := tcpip.FullAddress{
+				Addr: localIPv4Addr2,
+			}
+			serverConn, err := gonet.DialTCPWithBind(context.Background(), ctx.localStack, localServerAddr, remoteServerAddr, ipv4.ProtocolNumber)
+			if err != nil {
+				t.Logf("gonet.DialTCPWithBind(_, _, %+v, %+v, %d) = %s", localServerAddr, remoteServerAddr, ipv4.ProtocolNumber, err)
+				return
+			}
+			proxy(conn, serverConn)
+			// Proxying is complete; the per-connection goroutine exits here.
+		}()
+	}
+}
+
+// proxy copies the TCP payload between conn1 and conn2 in both directions and
+// returns once both copy directions have finished.
+func proxy(conn1, conn2 net.Conn) {
+	var wg sync.WaitGroup
+	// One Add per direction; each goroutine must signal Done or Wait() below
+	// blocks forever and leaks this goroutine.
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+		io.Copy(conn2, conn1)
+		conn1.Close()
+		conn2.Close()
+	}()
+	go func() {
+		defer wg.Done()
+		io.Copy(conn1, conn2)
+		conn1.Close()
+		conn2.Close()
+	}()
+	wg.Wait()
+}
+
+func (ctx *testContext) startHTTPServer() {
+	handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Write(ctx.defaultResponse) // defaultResponse is already a []byte.
+	})
+	s := &http.Server{
+		Handler: handlerFunc,
+	}
+	s.Serve(ctx.remoteServerListener)
+}
+
+func TestOutboundNATRedirect(t *testing.T) {
+	ctx := newTestContext(t)
+	defer ctx.cleanup()
+
+	// Install an iptables rule to redirect all TCP traffic with a source IP
+	// of localIPv4Addr1 to the TCP proxy port.
+	ipt := ctx.localStack.IPTables()
+	tbl := ipt.GetTable(stack.NATID, false /* ipv6 */)
+	ruleIdx := tbl.BuiltinChains[stack.Output]
+	tbl.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
+		Protocol:      tcp.ProtocolNumber,
+		CheckProtocol: true,
+		Src:           localIPv4Addr1,
+		SrcMask:       tcpip.Address("\xff\xff\xff\xff"),
+	}
+	tbl.Rules[ruleIdx].Target = &stack.RedirectTarget{
+		Port:            localServerPort,
+		NetworkProtocol: ipv4.ProtocolNumber,
+	}
+	tbl.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+	if err := ipt.ReplaceTable(stack.NATID, tbl, false /* ipv6 */); err != nil {
+		t.Fatalf("ipt.ReplaceTable(%d, _, false): %s", stack.NATID, err)
+	}
+
+	dialFunc := func(protocol, address string) (net.Conn, error) {
+		host, port, err := net.SplitHostPort(address)
+		if err != nil {
+			return nil, fmt.Errorf("unable to parse address: %s, err: %s", address, err)
+		}
+
+		remoteServerIP := net.ParseIP(host)
+		remoteServerPort, err := strconv.Atoi(port)
+		if err != nil {
+			return nil, fmt.Errorf("unable to parse port from string %s, err: %s", port, err)
+		}
+		remoteAddress := tcpip.FullAddress{
+			Addr: tcpip.Address(remoteServerIP.To4()),
+			Port: uint16(remoteServerPort),
+		}
+
+		// Dial with an explicit source address bound so that the redirect rule
+		// will be able to correctly redirect these packets.
+		localAddr := tcpip.FullAddress{Addr: localIPv4Addr1}
+		return gonet.DialTCPWithBind(context.Background(), ctx.localStack, localAddr, remoteAddress, ipv4.ProtocolNumber)
+	}
+
+	httpClient := &http.Client{
+		Transport: &http.Transport{
+			Dial: dialFunc,
+		},
+	}
+
+	serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(remoteIPv4Addr1), remoteServerPort)
+	response, err := httpClient.Get(serverURL)
+	if err != nil {
+		t.Fatalf("httpClient.Get(%q) failed: %s", serverURL, err)
+	}
+	defer response.Body.Close()
+	if got, want := response.StatusCode, http.StatusOK; got != want {
+		t.Fatalf("unexpected status code got: %d, want: %d", got, want)
+	}
+	body, err := io.ReadAll(response.Body)
+	if err != nil {
+		t.Fatalf("io.ReadAll(response.Body) failed: %s", err)
+	}
+	if diff := cmp.Diff(ctx.defaultResponse, body); diff != "" {
+		t.Fatalf("unexpected response (-want +got): \n %s", diff)
+	}
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 27caa0c28..95ddd8ec3 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -56,17 +56,17 @@ func setupStack(t *testing.T, stackOpts stack.Options, host1NICID, host2NICID tc
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv4Addr1, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv4Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv4Addr2, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv4Addr2, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, utils.Ipv6Addr1, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", host1NICID, utils.Ipv6Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, utils.Ipv6Addr2, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", host2NICID, utils.Ipv6Addr2, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
@@ -568,8 +568,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) {
Protocol: test.networkProtocolNumber,
AddressWithPrefix: test.incomingAddr,
}
- if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err)
+ if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", incomingNICID, incomingProtoAddr, err)
}
// Set up endpoint through which we will attempt to forward packets.
@@ -582,8 +582,8 @@ func TestForwardingWithLinkResolutionFailure(t *testing.T) {
Protocol: test.networkProtocolNumber,
AddressWithPrefix: test.outgoingAddr,
}
- if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err)
+ if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", outgoingNICID, outgoingProtoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index b2008f0b2..f33223e79 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -195,8 +195,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err)
+ if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -290,8 +290,8 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -431,8 +431,8 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err)
+ if err := s.AddProtocolAddress(nicID, test.addAddress, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, test.addAddress, err)
}
s.SetRouteTable([]tcpip.Route{
{
@@ -693,21 +693,40 @@ func TestExternalLoopbackTraffic(t *testing.T) {
if err := s.CreateNIC(nicID1, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddAddressWithPrefix(nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv4.ProtocolNumber, utils.Ipv4Addr, err)
+ v4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv4Addr,
}
- if err := s.AddAddressWithPrefix(nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID1, ipv6.ProtocolNumber, utils.Ipv6Addr, err)
+ if err := s.AddProtocolAddress(nicID1, v4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v4Addr, err)
+ }
+ v6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: utils.Ipv6Addr,
+ }
+ if err := s.AddProtocolAddress(nicID1, v6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, v6Addr, err)
}
if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
}
- if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, ipv4Loopback); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, ipv4Loopback, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: ipv4Loopback,
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err)
+ }
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: header.IPv6Loopback.WithPrefix(),
}
- if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, header.IPv6Loopback); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, header.IPv6Loopback, err)
+ if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err)
}
if test.forwarding {
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 2d0a6e6a7..7753e7d6e 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -119,12 +119,12 @@ func TestPingMulticastBroadcast(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr}
- if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtoAddr, err)
}
ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: utils.Ipv6Addr}
- if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtoAddr, err)
}
// Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
@@ -396,8 +396,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
var wq waiter.Queue
@@ -474,8 +474,8 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
PrefixLen: 8,
},
}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -642,8 +642,8 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr}
- if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err)
+ if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
}
// Set the route table so that UDP can find a NIC that is
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index ac3c703d4..422eb8408 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -47,7 +47,10 @@ func TestLocalPing(t *testing.T) {
// request/reply packets.
icmpDataOffset = 8
)
- ipv4Loopback := testutil.MustParse4("127.0.0.1")
+ ipv4Loopback := tcpip.AddressWithPrefix{
+ Address: testutil.MustParse4("127.0.0.1"),
+ PrefixLen: 8,
+ }
channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
@@ -82,7 +85,7 @@ func TestLocalPing(t *testing.T) {
transProto tcpip.TransportProtocolNumber
netProto tcpip.NetworkProtocolNumber
linkEndpoint func() stack.LinkEndpoint
- localAddr tcpip.Address
+ localAddr tcpip.AddressWithPrefix
icmpBuf func(*testing.T) buffer.View
expectedConnectErr tcpip.Error
checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint)
@@ -101,7 +104,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
linkEndpoint: loopback.New,
- localAddr: header.IPv6Loopback,
+ localAddr: header.IPv6Loopback.WithPrefix(),
icmpBuf: ipv6ICMPBuf,
checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
},
@@ -110,7 +113,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber4,
netProto: ipv4.ProtocolNumber,
linkEndpoint: channelEP,
- localAddr: utils.Ipv4Addr.Address,
+ localAddr: utils.Ipv4Addr,
icmpBuf: ipv4ICMPBuf,
checkLinkEndpoint: channelEPCheck,
},
@@ -119,7 +122,7 @@ func TestLocalPing(t *testing.T) {
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
linkEndpoint: channelEP,
- localAddr: utils.Ipv6Addr.Address,
+ localAddr: utils.Ipv6Addr,
icmpBuf: ipv6ICMPBuf,
checkLinkEndpoint: channelEPCheck,
},
@@ -182,9 +185,13 @@ func TestLocalPing(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if len(test.localAddr) != 0 {
- if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
+ if len(test.localAddr.Address) != 0 {
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: test.netProto,
+ AddressWithPrefix: test.localAddr,
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
}
}
@@ -197,7 +204,7 @@ func TestLocalPing(t *testing.T) {
}
defer ep.Close()
- connAddr := tcpip.FullAddress{Addr: test.localAddr}
+ connAddr := tcpip.FullAddress{Addr: test.localAddr.Address}
if err := ep.Connect(connAddr); err != test.expectedConnectErr {
t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
}
@@ -229,8 +236,8 @@ func TestLocalPing(t *testing.T) {
if diff := cmp.Diff(buffer.View(w.Bytes()[icmpDataOffset:]), payload[icmpDataOffset:]); diff != "" {
t.Errorf("received data mismatch (-want +got):\n%s", diff)
}
- if rr.RemoteAddr.Addr != test.localAddr {
- t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr)
+ if rr.RemoteAddr.Addr != test.localAddr.Address {
+ t.Errorf("got addr.Addr = %s, want = %s", rr.RemoteAddr.Addr, test.localAddr.Address)
}
test.checkLinkEndpoint(t, e)
@@ -302,11 +309,12 @@ func TestLocalUDP(t *testing.T) {
}
if subTest.addAddress {
- if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil {
- t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ if err := s.AddProtocolAddress(nicID, test.canBePrimaryAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, test.canBePrimaryAddr, err)
}
- if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint}
+ if err := s.AddProtocolAddress(nicID, test.firstPrimaryAddr, properties); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, %+v): %s", nicID, test.firstPrimaryAddr, properties, err)
}
}
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index 2e6ae55ea..c69410859 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -40,6 +40,14 @@ const (
Host2NICID = 4
)
+// Common NIC names used by tests.
+const (
+ Host1NICName = "host1NIC"
+ RouterNIC1Name = "routerNIC1"
+ RouterNIC2Name = "routerNIC2"
+ Host2NICName = "host2NIC"
+)
+
// Common link addresses used by tests.
const (
LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
@@ -211,17 +219,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2)
routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4)
- if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host1NICName}
+ if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil {
+ t.Fatalf("host1Stack.CreateNICWithOptions(%d, _, %#v): %s", Host1NICID, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC1Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err)
+ }
}
- if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err)
+ {
+ opts := stack.NICOptions{Name: RouterNIC2Name}
+ if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil {
+ t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err)
+ }
}
- if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err)
+ {
+ opts := stack.NICOptions{Name: Host2NICName}
+ if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil {
+ t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err)
+ }
}
if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
@@ -231,29 +251,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err)
}
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv4Addr, err)
+ if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv4Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
}
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv4Addr, err)
+ if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv4Addr, err)
}
- if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", Host1NICID, Host1IPv6Addr, err)
+ if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv6Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
+ if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
}
- if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", Host2NICID, Host2IPv6Addr, err)
+ if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv6Addr, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index bbc0e3ecc..4718ec4ec 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -33,6 +33,8 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/waiter",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 00497bf07..995f58616 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,6 +15,7 @@
package icmp
import (
+ "fmt"
"io"
"time"
@@ -24,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -35,15 +38,6 @@ type icmpPacket struct {
receivedAt time.Time `state:".(int64)"`
}
-type endpointState int
-
-const (
- stateInitial endpointState = iota
- stateBound
- stateConnected
- stateClosed
-)
-
// endpoint represents an ICMP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -51,14 +45,17 @@ const (
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
+ transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
uniqueID uint64
+ net network.Endpoint
+ stats tcpip.TransportEndpointStats
+ ops tcpip.SocketOptions
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -70,38 +67,23 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
- // shutdownFlags represent the current shutdown state of the endpoint.
- shutdownFlags tcpip.ShutdownFlags
- state endpointState
- route *stack.Route `state:"manual"`
- ttl uint8
- stats tcpip.TransportEndpointStats `state:"nosave"`
-
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
+ ident uint16
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: transProto,
- },
+ stack: s,
+ transProto: transProto,
waiterQueue: waiterQueue,
- state: stateInitial,
uniqueID: s.UniqueID(),
}
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetSendBufferSize(32*1024, false /* notify */)
ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ ep.net.Init(s, netProto, transProto, &ep.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -128,35 +110,40 @@ func (e *endpoint) Abort() {
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
- e.mu.Lock()
- e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.state {
- case stateBound, stateConnected:
- bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice)
- }
-
- // Close the receive list and drain it.
- e.rcvMu.Lock()
- e.rcvClosed = true
- e.rcvBufSize = 0
- for !e.rcvList.Empty() {
- p := e.rcvList.Front()
- e.rcvList.Remove(p)
- }
- e.rcvMu.Unlock()
+ notify := func() bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateClosed:
+ return false
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ info := e.net.Info()
+ info.ID.LocalPort = e.ident
+ e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{info.NetProto}, e.transProto, info.ID, e, ports.Flags{}, tcpip.NICID(e.ops.GetBindToDevice()))
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
+ e.net.Shutdown()
+ e.net.Close()
- // Update the state.
- e.state = stateClosed
+ e.rcvMu.Lock()
+ defer e.rcvMu.Unlock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
- e.mu.Unlock()
+ return true
+ }()
- e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+ if notify {
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+ }
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
@@ -164,7 +151,7 @@ func (*endpoint) ModerateRecvBuf(int) {}
// SetOwner implements tcpip.Endpoint.SetOwner.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.owner = owner
+ e.net.SetOwner(owner)
}
// Read implements tcpip.Endpoint.Read.
@@ -193,7 +180,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: p.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.receivedAt.UnixNano(),
+ Timestamp: p.receivedAt,
},
}
if opts.NeedRemoteAddr {
@@ -213,14 +200,13 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
-// +checklocks:e.mu
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
- switch e.state {
- case stateInitial:
- case stateConnected:
+// +checklocksread:e.mu
+func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
+ switch e.net.State() {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
return false, nil
-
- case stateBound:
+ case transport.DatagramEndpointStateBound:
if to == nil {
return false, &tcpip.ErrDestinationRequired{}
}
@@ -235,7 +221,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != stateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return true, nil
}
@@ -270,27 +256,15 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
-
- to := opts.To
-
+func (e *endpoint) prepareForWrite(opts tcpip.WriteOptions) (network.WriteContext, uint16, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- // If we've shutdown with SHUT_WR we are in an invalid state for sending.
- if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, &tcpip.ErrClosedForSend{}
- }
-
// Prepare for write.
for {
- retry, err := e.prepareForWrite(to)
+ retry, err := e.prepareForWriteInner(opts.To)
if err != nil {
- return 0, err
+ return network.WriteContext{}, 0, err
}
if !retry {
@@ -298,36 +272,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
}
- route := e.route
- if to != nil {
- // Reject destination address if it goes through a different
- // NIC than the endpoint was bound to.
- nicID := to.NIC
- if nicID == 0 {
- nicID = tcpip.NICID(e.ops.GetBindToDevice())
- }
- if e.BindNICID != 0 {
- if nicID != 0 && nicID != e.BindNICID {
- return 0, &tcpip.ErrNoRoute{}
- }
-
- nicID = e.BindNICID
- }
-
- dst, netProto, err := e.checkV4MappedLocked(*to)
- if err != nil {
- return 0, err
- }
-
- // Find the endpoint.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */)
- if err != nil {
- return 0, err
- }
- defer r.Release()
+ ctx, err := e.net.AcquireContextForWrite(opts)
+ return ctx, e.ident, err
+}
- route = r
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ ctx, ident, err := e.prepareForWrite(opts)
+ if err != nil {
+ return 0, err
}
+ defer ctx.Release()
// TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
v := make([]byte, p.Len())
@@ -335,17 +289,18 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return 0, &tcpip.ErrBadBuffer{}
}
- var err tcpip.Error
- switch e.NetProto {
+ switch netProto, pktInfo := e.net.NetProto(), ctx.PacketInfo(); netProto {
case header.IPv4ProtocolNumber:
- err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
+ if err := send4(e.stack, &ctx, ident, v, pktInfo.MaxHeaderLength); err != nil {
+ return 0, err
+ }
case header.IPv6ProtocolNumber:
- err = send6(route, e.ID.LocalPort, v, e.ttl)
- }
-
- if err != nil {
- return 0, err
+ if err := send6(e.stack, &ctx, ident, v, pktInfo.LocalAddress, pktInfo.RemoteAddress, pktInfo.MaxHeaderLength); err != nil {
+ return 0, err
+ }
+ default:
+ panic(fmt.Sprintf("unhandled network protocol = %d", netProto))
}
return int64(len(v)), nil
@@ -358,24 +313,17 @@ func (e *endpoint) HasNIC(id int32) bool {
return e.stack.HasNIC(tcpip.NICID(id))
}
-// SetSockOpt sets a socket option.
-func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
- return nil
+// SetSockOpt implements tcpip.Endpoint.
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
+ return e.net.SetSockOpt(opt)
}
-// SetSockOptInt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
- switch opt {
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(v)
- e.mu.Unlock()
-
- }
- return nil
+ return e.net.SetSockOptInt(opt, v)
}
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -388,31 +336,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.TTLOption:
- e.rcvMu.Lock()
- v := int(e.ttl)
- e.rcvMu.Unlock()
- return v, nil
-
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+// GetSockOpt implements tcpip.Endpoint.
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return e.net.GetSockOpt(opt)
}
-func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error {
+func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, maxHeaderLength uint16) tcpip.Error {
if len(data) < header.ICMPv4MinimumSize {
return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: header.ICMPv4MinimumSize + int(maxHeaderLength),
})
- pkt.Owner = owner
icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize))
pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
@@ -427,36 +368,31 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
return &tcpip.ErrInvalidEndpointState{}
}
- // Because this icmp endpoint is implemented in the transport layer, we can
- // only increment the 'stack-wide' stats but we can't increment the
- // 'per-NetworkEndpoint' stats.
- sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest
-
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
-
pkt.Data().AppendView(data)
- if ttl == 0 {
- ttl = r.DefaultTTL()
- }
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ stats := s.Stats().ICMP.V4.PacketsSent
- if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
- r.Stats().ICMP.V4.PacketsSent.Dropped.Increment()
+ if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ stats.Dropped.Increment()
return err
}
- sentStat.Increment()
+ stats.EchoRequest.Increment()
return nil
}
-func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error {
+func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, src, dst tcpip.Address, maxHeaderLength uint16) tcpip.Error {
if len(data) < header.ICMPv6EchoMinimumSize {
return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()),
+ ReserveHeaderBytes: header.ICMPv6MinimumSize + int(maxHeaderLength),
})
icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize))
@@ -469,43 +405,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
return &tcpip.ErrInvalidEndpointState{}
}
- // Because this icmp endpoint is implemented in the transport layer, we can
- // only increment the 'stack-wide' stats but we can't increment the
- // 'per-NetworkEndpoint' stats.
- sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest
pkt.Data().AppendView(data)
dataRange := pkt.Data().AsRange()
icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpv6,
- Src: r.LocalAddress(),
- Dst: r.RemoteAddress(),
+ Src: src,
+ Dst: dst,
PayloadCsum: dataRange.Checksum(),
PayloadLen: dataRange.Size(),
}))
- if ttl == 0 {
- ttl = r.DefaultTTL()
- }
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ stats := s.Stats().ICMP.V6.PacketsSent
- if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
- r.Stats().ICMP.V6.PacketsSent.Dropped.Increment()
+ if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
+ stats.Dropped.Increment()
+ return err
}
- sentStat.Increment()
+ stats.EchoRequest.Increment()
return nil
}
-// checkV4MappedLocked determines the effective network protocol and converts
-// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
- unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
- if err != nil {
- return tcpip.FullAddress{}, 0, err
- }
- return unwrapped, netProto, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect.
func (*endpoint) Disconnect() tcpip.Error {
return &tcpip.ErrNotSupported{}
@@ -516,59 +440,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicID := addr.NIC
- localPort := uint16(0)
- switch e.state {
- case stateInitial:
- case stateBound, stateConnected:
- localPort = e.ID.LocalPort
- if e.BindNICID == 0 {
- break
- }
+ err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
+ nextID.LocalPort = e.ident
- if nicID != 0 && nicID != e.BindNICID {
- return &tcpip.ErrInvalidEndpointState{}
+ nextID, err := e.registerWithStack(netProto, nextID)
+ if err != nil {
+ return err
}
- nicID = e.BindNICID
- default:
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
- if err != nil {
- return err
- }
-
- id := stack.TransportEndpointID{
- LocalAddress: r.LocalAddress(),
- LocalPort: localPort,
- RemoteAddress: r.RemoteAddress(),
- }
-
- // Even if we're connected, this endpoint can still be used to send
- // packets on a different network protocol, so we register both even if
- // v6only is set to false and this is an ipv6 endpoint.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
-
- id, err = e.registerWithStack(nicID, netProtos, id)
+ e.ident = nextID.LocalPort
+ return nil
+ })
if err != nil {
- r.Release()
return err
}
- e.ID = id
- e.route = r
- e.RegisterNICID = nicID
-
- e.state = stateConnected
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
@@ -586,10 +472,19 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- e.shutdownFlags |= flags
- if e.state != stateConnected {
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
return &tcpip.ErrNotConnected{}
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
+ }
+
+ if flags&tcpip.ShutdownWrite != 0 {
+ if err := e.net.Shutdown(); err != nil {
+ return err
+ }
}
if flags&tcpip.ShutdownRead != 0 {
@@ -616,19 +511,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
return nil, nil, &tcpip.ErrNotSupported{}
}
-func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
+func (e *endpoint) registerWithStack(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
- return id, err
+ return id, e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
+ err := e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
switch err.(type) {
case nil:
return true, nil
@@ -645,42 +539,27 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro
func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != stateInitial {
+ if e.net.State() != transport.DatagramEndpointStateInitial {
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
- if err != nil {
- return err
- }
-
- // Expand netProtos to include v4 and v6 if the caller is binding to a
- // wildcard (empty) address, and this is an IPv6 endpoint with v6only
- // set to false.
- netProtos := []tcpip.NetworkProtocolNumber{netProto}
-
- if len(addr.Addr) != 0 {
- // A local address was specified, verify that it's valid.
- if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
- return &tcpip.ErrBadLocalAddress{}
+ err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err := e.registerWithStack(boundNetProto, id)
+ if err != nil {
+ return err
}
- }
- id := stack.TransportEndpointID{
- LocalPort: addr.Port,
- LocalAddress: addr.Addr,
- }
- id, err = e.registerWithStack(addr.NIC, netProtos, id)
+ e.ident = id.LocalPort
+ return nil
+ })
if err != nil {
return err
}
- e.ID = id
- e.RegisterNICID = addr.NIC
-
- // Mark endpoint as bound.
- e.state = stateBound
-
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
@@ -688,21 +567,24 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
return nil
}
+func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast ||
+ header.IsV4MulticastAddress(addr) ||
+ header.IsV6MulticastAddress(addr) ||
+ e.stack.IsSubnetBroadcast(nicID, e.net.NetProto(), addr)
+}
+
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- err := e.bindLocked(addr)
- if err != nil {
- return err
+ if len(addr.Addr) != 0 && e.isBroadcastOrMulticast(addr.NIC, addr.Addr) {
+ return &tcpip.ErrBadLocalAddress{}
}
- e.BindNICID = addr.NIC
- e.BindAddr = addr.Addr
+ e.mu.Lock()
+ defer e.mu.Unlock()
- return nil
+ return e.bindLocked(addr)
}
// GetLocalAddress returns the address to which the endpoint is bound.
@@ -710,11 +592,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.LocalAddress,
- Port: e.ID.LocalPort,
- }, nil
+ addr := e.net.GetLocalAddress()
+ addr.Port = e.ident
+ return addr, nil
}
// GetRemoteAddress returns the address to which the endpoint is connected.
@@ -722,15 +602,11 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != stateConnected {
- return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
+ if addr, connected := e.net.GetRemoteAddress(); connected {
+ return addr, nil
}
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: e.ID.RemoteAddress,
- Port: e.ID.RemotePort,
- }, nil
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
// Readiness returns the current readiness of the endpoint. For example, if
@@ -755,7 +631,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Only accept echo replies.
- switch e.NetProto {
+ switch e.net.NetProto() {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(pkt.TransportHeader().View())
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
@@ -829,9 +705,9 @@ func (e *endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
+ defer e.mu.RUnlock()
+ ret := e.net.Info()
+ ret.ID.LocalPort = e.ident
return &ret
}
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index b8b839e4a..dfe453ff9 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -15,11 +15,13 @@
package icmp
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
)
// saveReceivedAt is invoked by stateify.
@@ -61,29 +63,24 @@ func (e *endpoint) beforeSave() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.thaw()
+
+ e.net.Resume(s)
+
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- if e.state != stateBound && e.state != stateConnected {
- return
- }
-
- var err tcpip.Error
- if e.state == stateConnected {
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */)
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
+ case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
+ var err tcpip.Error
+ info := e.net.Info()
+ info.ID.LocalPort = e.ident
+ info.ID, err = e.registerWithStack(info.NetProto, info.ID)
if err != nil {
- panic(err)
+ panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err))
}
-
- e.ID.LocalAddress = e.route.LocalAddress()
- } else if len(e.ID.LocalAddress) != 0 { // stateBound
- if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
- e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID)
- if err != nil {
- panic(err)
+ e.ident = info.ID.LocalPort
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
}
diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go
index cc950cbde..729f50e9a 100644
--- a/pkg/tcpip/transport/icmp/icmp_test.go
+++ b/pkg/tcpip/transport/icmp/icmp_test.go
@@ -55,8 +55,12 @@ func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name s
t.Fatalf("s.CreateNIC(%d, _) = %s", id, err)
}
- if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: addrV4.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
}
s.AddRoute(tcpip.Route{
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
index b1edce39b..3818cb04e 100644
--- a/pkg/tcpip/transport/internal/network/BUILD
+++ b/pkg/tcpip/transport/internal/network/BUILD
@@ -9,6 +9,7 @@ go_library(
"endpoint_state.go",
],
visibility = [
+ "//pkg/tcpip/transport/icmp:__pkg__",
"//pkg/tcpip/transport/raw:__pkg__",
"//pkg/tcpip/transport/udp:__pkg__",
],
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
index 09b629022..fb31e5104 100644
--- a/pkg/tcpip/transport/internal/network/endpoint.go
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -38,31 +38,65 @@ type Endpoint struct {
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
- // state holds a transport.DatagramBasedEndpointState.
- //
- // state must be read from/written to atomically.
- state uint32
-
- // The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
wasBound bool
- info stack.TransportEndpointInfo
// owner is the owner of transmitted packets.
- owner tcpip.PacketOwner
- writeShutdown bool
- effectiveNetProto tcpip.NetworkProtocolNumber
- connectedRoute *stack.Route `state:"manual"`
+ //
+ // +checklocks:mu
+ owner tcpip.PacketOwner
+ // +checklocks:mu
+ writeShutdown bool
+ // +checklocks:mu
+ effectiveNetProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
+ connectedRoute *stack.Route `state:"manual"`
+ // +checklocks:mu
multicastMemberships map[multicastMembership]struct{}
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
ttl uint8
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastTTL uint8
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastAddr tcpip.Address
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
+ // +checklocks:mu
multicastNICID tcpip.NICID
- ipv4TOS uint8
- ipv6TClass uint8
+ // +checklocks:mu
+ ipv4TOS uint8
+ // +checklocks:mu
+ ipv6TClass uint8
+
+ // Lock ordering: mu > infoMu.
+ infoMu sync.RWMutex `state:"nosave"`
+ // info has a dedicated mutex so that we can avoid lock ordering violations
+ // when reading the endpoint's info. If we used mu, we would need to
+ // guarantee that no lock taken while mu is held is also held when calling
+ // Info(), which is not true as of writing: we hold mu while registering
+ // transport endpoints (taking the transport demuxer lock), but we also hold
+ // the demuxer lock when delivering packets/errors to endpoints.
+ //
+ // Writes must be performed through setInfo.
+ //
+ // +checklocks:infoMu
+ info stack.TransportEndpointInfo
+
+ // state holds a transport.DatagramBasedEndpointState.
+ //
+ // state must be accessed with atomics so that we can avoid lock ordering
+ // violations when reading the state. If we used mu, we would need to
+ // guarantee that no lock taken while mu is held is also held when calling
+ // State(), which is not true as of writing: we hold mu while registering
+ // transport endpoints (taking the transport demuxer lock), but we also hold
+ // the demuxer lock when delivering packets/errors to endpoints.
+ //
+ // Writes must be performed through setEndpointState.
+ //
+ // +checkatomics
+ state uint32
}
// +stateify savable
@@ -73,8 +107,11 @@ type multicastMembership struct {
// Init initializes the endpoint.
func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions) {
- if e.multicastMemberships != nil {
- panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships))
+ e.mu.Lock()
+ memberships := e.multicastMemberships
+ e.mu.Unlock()
+ if memberships != nil {
+ panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", memberships))
}
switch netProto {
@@ -89,8 +126,6 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr
netProto: netProto,
transProto: transProto,
- state: uint32(transport.DatagramEndpointStateInitial),
-
info: stack.TransportEndpointInfo{
NetProto: netProto,
TransProto: transProto,
@@ -100,6 +135,10 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr
multicastTTL: 1,
multicastMemberships: make(map[multicastMembership]struct{}),
}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.setEndpointState(transport.DatagramEndpointStateInitial)
}
// NetProto returns the network protocol the endpoint was initialized with.
@@ -107,7 +146,12 @@ func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber {
return e.netProto
}
-// setState sets the state of the endpoint.
+// setEndpointState sets the state of the endpoint.
+//
+// e.mu must be held to synchronize changes to state with the rest of the
+// endpoint.
+//
+// +checklocks:e.mu
func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) {
atomic.StoreUint32(&e.state, uint32(state))
}
@@ -242,23 +286,24 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
if nicID == 0 {
nicID = tcpip.NICID(e.ops.GetBindToDevice())
}
- if e.info.BindNICID != 0 {
- if nicID != 0 && nicID != e.info.BindNICID {
+ info := e.Info()
+ if info.BindNICID != 0 {
+ if nicID != 0 && nicID != info.BindNICID {
return WriteContext{}, &tcpip.ErrNoRoute{}
}
- nicID = e.info.BindNICID
+ nicID = info.BindNICID
}
if nicID == 0 {
- nicID = e.info.RegisterNICID
+ nicID = info.RegisterNICID
}
- dst, netProto, err := e.checkV4MappedLocked(*opts.To)
+ dst, netProto, err := e.checkV4Mapped(*opts.To)
if err != nil {
return WriteContext{}, err
}
- route, _, err = e.connectRoute(nicID, dst, netProto)
+ route, _, err = e.connectRouteRLocked(nicID, dst, netProto)
if err != nil {
return WriteContext{}, err
}
@@ -297,26 +342,30 @@ func (e *Endpoint) Disconnect() {
return
}
+ info := e.Info()
// Exclude ephemerally bound endpoints.
if e.wasBound {
- e.info.ID = stack.TransportEndpointID{
- LocalAddress: e.info.BindAddr,
+ info.ID = stack.TransportEndpointID{
+ LocalAddress: info.BindAddr,
}
e.setEndpointState(transport.DatagramEndpointStateBound)
} else {
- e.info.ID = stack.TransportEndpointID{}
+ info.ID = stack.TransportEndpointID{}
e.setEndpointState(transport.DatagramEndpointStateInitial)
}
+ e.setInfo(info)
e.connectedRoute.Release()
e.connectedRoute = nil
}
-// connectRoute establishes a route to the specified interface or the
+// connectRouteRLocked establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *Endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
- localAddr := e.info.ID.LocalAddress
+//
+// +checklocksread:e.mu
+func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
+ localAddr := e.Info().ID.LocalAddress
if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
localAddr = ""
@@ -359,42 +408,43 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
e.mu.Lock()
defer e.mu.Unlock()
+ info := e.Info()
nicID := addr.NIC
switch e.State() {
case transport.DatagramEndpointStateInitial:
case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
- if e.info.BindNICID == 0 {
+ if info.BindNICID == 0 {
break
}
- if nicID != 0 && nicID != e.info.BindNICID {
+ if nicID != 0 && nicID != info.BindNICID {
return &tcpip.ErrInvalidEndpointState{}
}
- nicID = e.info.BindNICID
+ nicID = info.BindNICID
default:
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
+ addr, netProto, err := e.checkV4Mapped(addr)
if err != nil {
return err
}
- r, nicID, err := e.connectRoute(nicID, addr, netProto)
+ r, nicID, err := e.connectRouteRLocked(nicID, addr, netProto)
if err != nil {
return err
}
id := stack.TransportEndpointID{
- LocalAddress: e.info.ID.LocalAddress,
+ LocalAddress: info.ID.LocalAddress,
RemoteAddress: r.RemoteAddress(),
}
if e.State() == transport.DatagramEndpointStateInitial {
id.LocalAddress = r.LocalAddress()
}
- if err := f(r.NetProto(), e.info.ID, id); err != nil {
+ if err := f(r.NetProto(), info.ID, id); err != nil {
return err
}
@@ -403,8 +453,9 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.
e.connectedRoute.Release()
}
e.connectedRoute = r
- e.info.ID = id
- e.info.RegisterNICID = nicID
+ info.ID = id
+ info.RegisterNICID = nicID
+ e.setInfo(info)
e.effectiveNetProto = netProto
e.setEndpointState(transport.DatagramEndpointStateConnected)
return nil
@@ -426,10 +477,11 @@ func (e *Endpoint) Shutdown() tcpip.Error {
}
}
-// checkV4MappedLocked determines the effective network protocol and converts
+// checkV4Mapped determines the effective network protocol and converts
// addr to its canonical form.
-func (e *Endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
- unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
+func (e *Endpoint) checkV4Mapped(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
+ info := e.Info()
+ unwrapped, netProto, err := info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
}
@@ -464,7 +516,7 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
return &tcpip.ErrInvalidEndpointState{}
}
- addr, netProto, err := e.checkV4MappedLocked(addr)
+ addr, netProto, err := e.checkV4Mapped(addr)
if err != nil {
return err
}
@@ -483,12 +535,14 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto
e.wasBound = true
- e.info.ID = stack.TransportEndpointID{
+ info := e.Info()
+ info.ID = stack.TransportEndpointID{
LocalAddress: addr.Addr,
}
- e.info.BindNICID = addr.NIC
- e.info.RegisterNICID = nicID
- e.info.BindAddr = addr.Addr
+ info.BindNICID = addr.NIC
+ info.RegisterNICID = nicID
+ info.BindAddr = addr.Addr
+ e.setInfo(info)
e.effectiveNetProto = netProto
e.setEndpointState(transport.DatagramEndpointStateBound)
return nil
@@ -506,13 +560,14 @@ func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
e.mu.RLock()
defer e.mu.RUnlock()
- addr := e.info.BindAddr
+ info := e.Info()
+ addr := info.BindAddr
if e.State() == transport.DatagramEndpointStateConnected {
addr = e.connectedRoute.LocalAddress()
}
return tcpip.FullAddress{
- NIC: e.info.RegisterNICID,
+ NIC: info.RegisterNICID,
Addr: addr,
}
}
@@ -528,7 +583,7 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) {
return tcpip.FullAddress{
Addr: e.connectedRoute.RemoteAddress(),
- NIC: e.info.RegisterNICID,
+ NIC: e.Info().RegisterNICID,
}, true
}
@@ -610,7 +665,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- fa, netProto, err := e.checkV4MappedLocked(fa)
+ fa, netProto, err := e.checkV4Mapped(fa)
if err != nil {
return err
}
@@ -634,7 +689,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
}
}
- if e.info.BindNICID != 0 && e.info.BindNICID != nic {
+ if info := e.Info(); info.BindNICID != 0 && info.BindNICID != nic {
return &tcpip.ErrInvalidEndpointState{}
}
@@ -737,7 +792,19 @@ func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
// Info returns a copy of the endpoint info.
func (e *Endpoint) Info() stack.TransportEndpointInfo {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.infoMu.RLock()
+ defer e.infoMu.RUnlock()
return e.info
}
+
+// setInfo sets the endpoint's info.
+//
+// e.mu must be held to synchronize changes to info with the rest of the
+// endpoint.
+//
+// +checklocks:e.mu
+func (e *Endpoint) setInfo(info stack.TransportEndpointInfo) {
+ e.infoMu.Lock()
+ defer e.infoMu.Unlock()
+ e.info = info
+}
diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go
index 858007156..68bd1fbf6 100644
--- a/pkg/tcpip/transport/internal/network/endpoint_state.go
+++ b/pkg/tcpip/transport/internal/network/endpoint_state.go
@@ -35,20 +35,22 @@ func (e *Endpoint) Resume(s *stack.Stack) {
}
}
+ info := e.Info()
+
switch state := e.State(); state {
case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
case transport.DatagramEndpointStateBound:
- if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) {
- if e.stack.CheckLocalAddress(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) == 0 {
- panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress))
+ if len(info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) {
+ if e.stack.CheckLocalAddress(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) == 0 {
+ panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress))
}
}
case transport.DatagramEndpointStateConnected:
var err tcpip.Error
multicastLoop := e.ops.GetMulticastLoop()
- e.connectedRoute, err = e.stack.FindRoute(e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop)
+ e.connectedRoute, err = e.stack.FindRoute(info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop)
if err != nil {
- panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
+ panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err))
}
default:
panic(fmt.Sprintf("unhandled state = %s", state))
diff --git a/pkg/tcpip/transport/internal/network/endpoint_test.go b/pkg/tcpip/transport/internal/network/endpoint_test.go
index d99c961c3..f263a9ea2 100644
--- a/pkg/tcpip/transport/internal/network/endpoint_test.go
+++ b/pkg/tcpip/transport/internal/network/endpoint_test.go
@@ -124,11 +124,20 @@ func TestEndpointStateTransitions(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err)
+ ipv4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ipv4NICAddr.WithPrefix(),
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err)
+ }
+ ipv6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: ipv6NICAddr.WithPrefix(),
+ }
+
+ if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -257,11 +266,19 @@ func TestBindNICID(t *testing.T) {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4NICAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4NICAddr, err)
+ ipv4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: ipv4NICAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, ipv4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv4ProtocolAddr, err)
+ }
+ ipv6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: ipv6NICAddr.WithPrefix(),
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6NICAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6NICAddr, err)
+ if err := s.AddProtocolAddress(nicID, ipv6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, ipv6ProtocolAddr, err)
}
var ops tcpip.SocketOptions
diff --git a/pkg/tcpip/transport/internal/noop/BUILD b/pkg/tcpip/transport/internal/noop/BUILD
new file mode 100644
index 000000000..171c41eb1
--- /dev/null
+++ b/pkg/tcpip/transport/internal/noop/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "noop",
+ srcs = ["endpoint.go"],
+ visibility = ["//pkg/tcpip/transport/raw:__pkg__"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/transport/internal/noop/endpoint.go b/pkg/tcpip/transport/internal/noop/endpoint.go
new file mode 100644
index 000000000..443b4e416
--- /dev/null
+++ b/pkg/tcpip/transport/internal/noop/endpoint.go
@@ -0,0 +1,172 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package noop contains an endpoint that implements all tcpip.Endpoint
+// functions as noops.
+package noop
+
+import (
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// endpoint can be created, but all interactions have no effect or
+// return errors.
+//
+// +stateify savable
+type endpoint struct {
+ tcpip.DefaultSocketOptionsHandler
+ ops tcpip.SocketOptions
+}
+
+// New returns an initialized noop endpoint.
+func New(stk *stack.Stack) tcpip.Endpoint {
+ // ep.ops must be in a valid, initialized state for callers of
+ // ep.SocketOptions.
+ var ep endpoint
+ ep.ops.InitHandler(&ep, stk, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
+ return &ep
+}
+
+// Abort implements stack.TransportEndpoint.Abort.
+func (*endpoint) Abort() {
+ // No-op.
+}
+
+// Close implements tcpip.Endpoint.Close.
+func (*endpoint) Close() {
+ // No-op.
+}
+
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (*endpoint) ModerateRecvBuf(int) {
+ // No-op.
+}
+
+func (*endpoint) SetOwner(tcpip.PacketOwner) {
+ // No-op.
+}
+
+// Read implements tcpip.Endpoint.Read.
+func (*endpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
+ return tcpip.ReadResult{}, &tcpip.ErrNotPermitted{}
+}
+
+// Write implements tcpip.Endpoint.Write.
+func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) {
+ return 0, &tcpip.ErrNotPermitted{}
+}
+
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
+// Connect implements tcpip.Endpoint.Connect.
+func (*endpoint) Connect(tcpip.FullAddress) tcpip.Error {
+ return &tcpip.ErrNotPermitted{}
+}
+
+// Shutdown implements tcpip.Endpoint.Shutdown.
+func (*endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
+ return &tcpip.ErrNotPermitted{}
+}
+
+// Listen implements tcpip.Endpoint.Listen.
+func (*endpoint) Listen(int) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
+}
+
+// Accept implements tcpip.Endpoint.Accept.
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
+ return nil, nil, &tcpip.ErrNotSupported{}
+}
+
+// Bind implements tcpip.Endpoint.Bind.
+func (*endpoint) Bind(tcpip.FullAddress) tcpip.Error {
+ return &tcpip.ErrNotPermitted{}
+}
+
+// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
+ return tcpip.FullAddress{}, &tcpip.ErrNotSupported{}
+}
+
+// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
+}
+
+// Readiness implements tcpip.Endpoint.Readiness.
+func (*endpoint) Readiness(waiter.EventMask) waiter.EventMask {
+ return 0
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
+}
+
+func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (*endpoint) GetSockOptInt(tcpip.SockOptInt) (int, tcpip.Error) {
+ return 0, &tcpip.ErrUnknownProtocolOption{}
+}
+
+// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
+func (*endpoint) HandlePacket(pkt *stack.PacketBuffer) {
+ panic(fmt.Sprintf("unreachable: noop.endpoint should never be registered, but got packet: %+v", pkt))
+}
+
+// State implements socket.Socket.State.
+func (*endpoint) State() uint32 {
+ return 0
+}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (*endpoint) Wait() {
+ // No-op.
+}
+
+// LastError implements tcpip.Endpoint.LastError.
+func (*endpoint) LastError() tcpip.Error {
+ return nil
+}
+
+// SocketOptions implements tcpip.Endpoint.SocketOptions.
+func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
+ return &ep.ops
+}
+
+// Info implements tcpip.Endpoint.Info.
+func (*endpoint) Info() tcpip.EndpointInfo {
+ return &stack.TransportEndpointInfo{}
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (*endpoint) Stats() tcpip.EndpointStats {
+ return &tcpip.TransportEndpointStats{}
+}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 0554d2f4a..80eef39e9 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -59,52 +59,47 @@ type packet struct {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
- netProto tcpip.NetworkProtocolNumber
waiterQueue *waiter.Queue
cooked bool
-
- // The following fields are used to manage the receive queue and are
- // protected by rcvMu.
- rcvMu sync.Mutex `state:"nosave"`
- rcvList packetList
+ ops tcpip.SocketOptions
+ stats tcpip.TransportEndpointStats
+
+ // The following fields are used to manage the receive queue.
+ rcvMu sync.Mutex `state:"nosave"`
+ // +checklocks:rcvMu
+ rcvList packetList
+ // +checklocks:rcvMu
rcvBufSize int
- rcvClosed bool
-
- // The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
- bound bool
+ // +checklocks:rcvMu
+ rcvClosed bool
+ // +checklocks:rcvMu
+ rcvDisabled bool
+
+ mu sync.RWMutex `state:"nosave"`
+ // +checklocks:mu
+ closed bool
+ // +checklocks:mu
+ boundNetProto tcpip.NetworkProtocolNumber
+ // +checklocks:mu
boundNIC tcpip.NICID
- // lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
- lastError tcpip.Error
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
- // frozen indicates if the packets should be delivered to the endpoint
- // during restore.
- frozen bool
+ // +checklocks:lastErrorMu
+ lastError tcpip.Error
}
// NewEndpoint returns a new packet endpoint.
func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- },
- cooked: cooked,
- netProto: netProto,
- waiterQueue: waiterQueue,
+ stack: s,
+ cooked: cooked,
+ boundNetProto: netProto,
+ waiterQueue: waiterQueue,
}
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
@@ -140,7 +135,7 @@ func (ep *endpoint) Close() {
return
}
- ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
ep.rcvMu.Lock()
defer ep.rcvMu.Unlock()
@@ -153,7 +148,6 @@ func (ep *endpoint) Close() {
}
ep.closed = true
- ep.bound = false
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
@@ -188,7 +182,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
Total: packet.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: packet.receivedAt.UnixNano(),
+ Timestamp: packet.receivedAt,
},
}
if opts.NeedRemoteAddr {
@@ -214,13 +208,13 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc
ep.mu.Lock()
closed := ep.closed
nicID := ep.boundNIC
+ proto := ep.boundNetProto
ep.mu.Unlock()
if closed {
return 0, &tcpip.ErrClosedForSend{}
}
var remote tcpip.LinkAddress
- proto := ep.netProto
if to := opts.To; to != nil {
remote = tcpip.LinkAddress(to.Addr)
@@ -296,29 +290,42 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound && ep.boundNIC == addr.NIC {
- // If the NIC being bound is the same then just return success.
+ netProto := tcpip.NetworkProtocolNumber(addr.Port)
+ if netProto == 0 {
+ // Do not allow unbinding the network protocol.
+ netProto = ep.boundNetProto
+ }
+
+ if ep.boundNIC == addr.NIC && ep.boundNetProto == netProto {
+ // Already bound to the requested NIC and network protocol.
return nil
}
- // Unregister endpoint with all the nics.
- ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
- ep.bound = false
+ // TODO(https://gvisor.dev/issue/6618): Unregister after registering the new
+ // binding.
+ ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
+ ep.boundNIC = 0
+ ep.boundNetProto = 0
// Bind endpoint to receive packets from specific interface.
- if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil {
return err
}
- ep.bound = true
ep.boundNIC = addr.NIC
-
+ ep.boundNetProto = netProto
return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
- return tcpip.FullAddress{}, &tcpip.ErrNotSupported{}
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ NIC: ep.boundNIC,
+ Port: uint16(ep.boundNetProto),
+ }, nil
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
@@ -402,7 +409,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
}
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, _ tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
ep.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -414,7 +421,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
}
rcvBufSize := ep.ops.GetReceiveBufferSize()
- if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) {
+ if ep.rcvDisabled || ep.rcvBufSize >= int(rcvBufSize) {
ep.rcvMu.Unlock()
ep.stack.Stats().DroppedPackets.Increment()
ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
@@ -473,10 +480,8 @@ func (*endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (ep *endpoint) Info() tcpip.EndpointInfo {
ep.mu.RLock()
- // Make a copy of the endpoint info.
- ret := ep.TransportEndpointInfo
- ep.mu.RUnlock()
- return &ret
+ defer ep.mu.RUnlock()
+ return &stack.TransportEndpointInfo{NetProto: ep.boundNetProto}
}
// Stats returns a pointer to the endpoint stats.
@@ -491,18 +496,3 @@ func (*endpoint) SetOwner(tcpip.PacketOwner) {}
func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
return &ep.ops
}
-
-// freeze prevents any more packets from being delivered to the endpoint.
-func (ep *endpoint) freeze() {
- ep.mu.Lock()
- ep.frozen = true
- ep.mu.Unlock()
-}
-
-// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
-// new packets to be delivered again.
-func (ep *endpoint) thaw() {
- ep.mu.Lock()
- ep.frozen = false
- ep.mu.Unlock()
-}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 5c688d286..88cd80ad3 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -44,17 +45,24 @@ func (p *packet) loadData(data buffer.VectorisedView) {
// beforeSave is invoked by stateify.
func (ep *endpoint) beforeSave() {
- ep.freeze()
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+ ep.rcvDisabled = true
}
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
- ep.thaw()
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
ep.stack = stack.StackFromEnv
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
- if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
- panic(err)
+ if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil {
+ panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err))
}
+
+ ep.rcvMu.Lock()
+ ep.rcvDisabled = false
+ ep.rcvMu.Unlock()
}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index b7e97e218..10b0c35fb 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -35,6 +35,7 @@ go_library(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport",
"//pkg/tcpip/transport/internal/network",
+ "//pkg/tcpip/transport/internal/noop",
"//pkg/tcpip/transport/packet",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 3040a445b..ce76774af 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -49,6 +49,7 @@ type rawPacket struct {
receivedAt time.Time `state:".(int64)"`
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
}
// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to
@@ -70,7 +71,7 @@ type endpoint struct {
associated bool
net network.Endpoint
- stats tcpip.TransportEndpointStats `state:"nosave"`
+ stats tcpip.TransportEndpointStats
ops tcpip.SocketOptions
// The following fields are used to manage the receive queue and are
@@ -202,12 +203,29 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
Total: pkt.data.Size(),
ControlMessages: tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: pkt.receivedAt.UnixNano(),
+ Timestamp: pkt.receivedAt,
},
}
if opts.NeedRemoteAddr {
res.RemoteAddr = pkt.senderAddr
}
+ switch netProto := e.net.NetProto(); netProto {
+ case header.IPv4ProtocolNumber:
+ if e.ops.GetReceivePacketInfo() {
+ res.ControlMessages.HasIPPacketInfo = true
+ res.ControlMessages.PacketInfo = pkt.packetInfo
+ }
+ case header.IPv6ProtocolNumber:
+ if e.ops.GetIPv6ReceivePacketInfo() {
+ res.ControlMessages.HasIPv6PacketInfo = true
+ res.ControlMessages.IPv6PacketInfo = tcpip.IPv6PacketInfo{
+ NIC: pkt.packetInfo.NIC,
+ Addr: pkt.packetInfo.DestinationAddr,
+ }
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized network protocol = %d", netProto))
+ }
n, err := pkt.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
@@ -435,7 +453,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return false
}
- srcAddr := pkt.Network().SourceAddress()
+ net := pkt.Network()
+ dstAddr := net.DestinationAddress()
+ srcAddr := net.SourceAddress()
info := e.net.Info()
switch state := e.net.State(); state {
@@ -457,7 +477,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
// If bound to an address, only accept data for that address.
- if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() {
+ if info.BindAddr != "" && info.BindAddr != dstAddr {
return false
}
default:
@@ -472,6 +492,14 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
NIC: pkt.NICID,
Addr: srcAddr,
},
+ packetInfo: tcpip.IPPacketInfo{
+ // TODO(gvisor.dev/issue/3556): dstAddr may be a multicast or broadcast
+ // address. LocalAddr should hold a unicast address that can be
+ // used to respond to the incoming packet.
+ LocalAddr: dstAddr,
+ DestinationAddr: dstAddr,
+ NIC: pkt.NICID,
+ },
}
// Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
@@ -483,10 +511,10 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// overlapping slices.
var combinedVV buffer.VectorisedView
if info.NetProto == header.IPv4ProtocolNumber {
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
- headers := make(buffer.View, 0, len(network)+len(transport))
- headers = append(headers, network...)
- headers = append(headers, transport...)
+ networkHeader, transportHeader := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ headers := make(buffer.View, 0, len(networkHeader)+len(transportHeader))
+ headers = append(headers, networkHeader...)
+ headers = append(headers, transportHeader...)
combinedVV = headers.ToVectorisedView()
} else {
combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index e393b993d..624e2dbe7 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -17,6 +17,7 @@ package raw
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/noop"
"gvisor.dev/gvisor/pkg/tcpip/transport/packet"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -33,3 +34,18 @@ func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpi
func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return packet.NewEndpoint(stack, cooked, netProto, waiterQueue)
}
+
+// CreateOnlyFactory implements stack.RawFactory. It allows creation of raw
+// endpoints that do not support reading, writing, binding, etc.
+type CreateOnlyFactory struct{}
+
+// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint.
+func (CreateOnlyFactory) NewUnassociatedEndpoint(stk *stack.Stack, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
+ return noop.New(stk), nil
+}
+
+// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint.
+func (CreateOnlyFactory) NewPacketEndpoint(*stack.Stack, bool, tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
+ // This isn't needed by anything, so it isn't implemented.
+ return nil, &tcpip.ErrNotPermitted{}
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 5148fe157..20958d882 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -80,9 +80,10 @@ go_library(
go_test(
name = "tcp_x_test",
- size = "medium",
+ size = "large",
srcs = [
"dual_stack_test.go",
+ "rcv_test.go",
"sack_scoreboard_test.go",
"tcp_noracedetector_test.go",
"tcp_rack_test.go",
@@ -114,16 +115,6 @@ go_test(
)
go_test(
- name = "rcv_test",
- size = "small",
- srcs = ["rcv_test.go"],
- deps = [
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- ],
-)
-
-go_test(
name = "tcp_test",
size = "small",
srcs = [
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 03c9fafa1..caf14b0dc 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -15,12 +15,12 @@
package tcp
import (
+ "container/list"
"crypto/sha1"
"encoding/binary"
"fmt"
"hash"
"io"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
@@ -100,18 +100,6 @@ type listenContext struct {
// netProto indicates the network protocol(IPv4/v6) for the listening
// endpoint.
netProto tcpip.NetworkProtocolNumber
-
- // pendingMu protects pendingEndpoints. This should only be accessed
- // by the listening endpoint's worker goroutine.
- //
- // Lock Ordering: listenEP.workerMu -> pendingMu
- pendingMu sync.Mutex
- // pending is used to wait for all pendingEndpoints to finish when
- // a socket is closed.
- pending sync.WaitGroup
- // pendingEndpoints is a map of all endpoints for which a handshake is
- // in progress.
- pendingEndpoints map[stack.TransportEndpointID]*endpoint
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
@@ -122,14 +110,13 @@ func timeStamp(clock tcpip.Clock) uint32 {
// newListenContext creates a new listen context.
func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
- stack: stk,
- protocol: protocol,
- rcvWnd: rcvWnd,
- hasher: sha1.New(),
- v6Only: v6Only,
- netProto: netProto,
- listenEP: listenEP,
- pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
+ stack: stk,
+ protocol: protocol,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6Only: v6Only,
+ netProto: netProto,
+ listenEP: listenEP,
}
for i := range l.nonce {
@@ -193,14 +180,6 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
}
-func (l *listenContext) useSynCookies() bool {
- var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
- if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
- panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
- }
- return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull())
-}
-
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
@@ -273,18 +252,15 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
return nil, &tcpip.ErrConnectionAborted{}
}
- l.addPendingEndpoint(ep)
// Propagate any inheritable options from the listening endpoint
// to the newly created endpoint.
- l.listenEP.propagateInheritableOptionsLocked(ep)
+ l.listenEP.propagateInheritableOptionsLocked(ep) // +checklocksforce
if !ep.reserveTupleLocked() {
ep.mu.Unlock()
ep.Close()
- l.removePendingEndpoint(ep)
-
return nil, &tcpip.ErrConnectionAborted{}
}
@@ -303,10 +279,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
ep.mu.Unlock()
ep.Close()
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- }
-
ep.drainClosingSegmentQueue()
return nil, err
@@ -344,39 +316,12 @@ func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions,
return ep, nil
}
-func (l *listenContext) addPendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- l.pendingEndpoints[n.TransportEndpointInfo.ID] = n
- l.pending.Add(1)
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) removePendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- delete(l.pendingEndpoints, n.TransportEndpointInfo.ID)
- l.pending.Done()
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) closeAllPendingEndpoints() {
- l.pendingMu.Lock()
- for _, n := range l.pendingEndpoints {
- n.notifyProtocolGoroutine(notifyClose)
- }
- l.pendingMu.Unlock()
- l.pending.Wait()
-}
-
-// Precondition: h.ep.mu must be held.
// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
e.mu.Unlock()
e.Close()
e.notifyAborted()
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.drainClosingSegmentQueue()
e.h = nil
}
@@ -384,12 +329,9 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) {
// cleanupCompletedHandshake transfers any state from the completed handshake to
// the new endpoint.
//
-// Precondition: h.ep.mu must be held.
+// +checklocks:h.ep.mu
func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e := h.ep
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.isConnectNotified = true
// Update the receive window scaling. We can't do it before the
@@ -401,47 +343,11 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e.h = nil
}
-// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// listener has transitioned out of the listen state (accepted is the zero
-// value), the new endpoint is reset instead.
-func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) {
- e.mu.Lock()
- e.pendingAccepted.Add(1)
- e.mu.Unlock()
- defer e.pendingAccepted.Done()
-
- // Drop the lock before notifying to avoid deadlock in user-specified
- // callbacks.
- delivered := func() bool {
- e.acceptMu.Lock()
- defer e.acceptMu.Unlock()
- for {
- if e.accepted == (accepted{}) {
- return false
- }
- if e.accepted.endpoints.Len() == e.accepted.cap {
- e.acceptCond.Wait()
- continue
- }
-
- e.accepted.endpoints.PushBack(n)
- if !withSynCookie {
- atomic.AddInt32(&e.synRcvdCount, -1)
- }
- return true
- }
- }()
- if delivered {
- e.waiterQueue.Notify(waiter.ReadableEvents)
- } else {
- n.notifyProtocolGoroutine(notifyReset)
- }
-}
-
// propagateInheritableOptionsLocked propagates any options set on the listening
// endpoint to the newly created endpoint.
//
-// Precondition: e.mu and n.mu must be held.
+// +checklocks:e.mu
+// +checklocks:n.mu
func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
n.userTimeout = e.userTimeout
n.portFlags = e.portFlags
@@ -452,9 +358,9 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
// reserveTupleLocked reserves an accepted endpoint's tuple.
//
-// Preconditions:
-// * propagateInheritableOptionsLocked has been called.
-// * e.mu is held.
+// Precondition: e.propagateInheritableOptionsLocked has been called.
+//
+// +checklocks:e.mu
func (e *endpoint) reserveTupleLocked() bool {
dest := tcpip.FullAddress{
Addr: e.TransportEndpointInfo.ID.RemoteAddress,
@@ -489,70 +395,36 @@ func (e *endpoint) notifyAborted() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
-// handleSynSegment is called in its own goroutine once the listening endpoint
-// receives a SYN segment. It is responsible for completing the handshake and
-// queueing the new endpoint for acceptance.
-//
-// A limited number of these goroutines are allowed before TCP starts using SYN
-// cookies to accept connections.
-//
-// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.TCPSynOptions) tcpip.Error {
- defer s.decRef()
+func (e *endpoint) acceptQueueIsFull() bool {
+ e.acceptMu.Lock()
+ full := e.acceptQueue.isFull()
+ e.acceptMu.Unlock()
+ return full
+}
- h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
- if err != nil {
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
- atomic.AddInt32(&e.synRcvdCount, -1)
- return err
- }
+// +stateify savable
+type acceptQueue struct {
+ // NB: this could be an endpointList, but ilist only permits endpoints to
+ // belong to one list at a time, and endpoints are already stored in the
+ // dispatcher's list.
+ endpoints list.List `state:".([]*endpoint)"`
- go func() {
- // Note that startHandshake returns a locked endpoint. The
- // force call here just makes it so.
- if err := h.complete(); err != nil { // +checklocksforce
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
- ctx.cleanupFailedHandshake(h)
- atomic.AddInt32(&e.synRcvdCount, -1)
- return
- }
- ctx.cleanupCompletedHandshake(h)
- h.ep.startAcceptedLoop()
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(h.ep, false /*withSynCookie*/)
- }()
+ // pendingEndpoints is a set of all endpoints for which a handshake is
+ // in progress.
+ pendingEndpoints map[*endpoint]struct{}
- return nil
+ // capacity is the maximum number of endpoints that can be in endpoints.
+ capacity int
}
-func (e *endpoint) synRcvdBacklogFull() bool {
- e.acceptMu.Lock()
- acceptedCap := e.accepted.cap
- e.acceptMu.Unlock()
- // The capacity of the accepted queue would always be one greater than the
- // listen backlog. But, the SYNRCVD connections count is always checked
- // against the listen backlog value for Linux parity reason.
- // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
- //
- // We maintain an equality check here as the synRcvdCount is incremented
- // and compared only from a single listener context and the capacity of
- // the accepted queue can only increase by a new listen call.
- return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
-}
-
-func (e *endpoint) acceptQueueIsFull() bool {
- e.acceptMu.Lock()
- full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap
- e.acceptMu.Unlock()
- return full
+func (a *acceptQueue) isFull() bool {
+ return a.endpoints.Len() == a.capacity
}
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
//
-// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
+// +checklocks:e.mu
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error {
e.rcvQueueInfo.rcvQueueMu.Lock()
rcvClosed := e.rcvQueueInfo.RcvClosed
@@ -580,11 +452,95 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
opts := parseSynSegmentOptions(s)
- if !ctx.useSynCookies() {
- s.incRef()
- atomic.AddInt32(&e.synRcvdCount, 1)
- return e.handleSynSegment(ctx, s, opts)
+
+ useSynCookies, err := func() (bool, tcpip.Error) {
+ var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
+ if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
+ panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
+ }
+ if alwaysUseSynCookies {
+ return true, nil
+ }
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+
+ // The capacity of the accepted queue would always be one greater than the
+ // listen backlog. But the SYNRCVD connections count is always checked
+ // against the listen backlog value for parity with Linux.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
+ if len(e.acceptQueue.pendingEndpoints) == e.acceptQueue.capacity-1 {
+ return true, nil
+ }
+
+ h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
+ if err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ return false, err
+ }
+
+ e.acceptQueue.pendingEndpoints[h.ep] = struct{}{}
+ e.pendingAccepted.Add(1)
+
+ go func() {
+ defer func() {
+ e.pendingAccepted.Done()
+
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ delete(e.acceptQueue.pendingEndpoints, h.ep)
+ }()
+
+ // Note that startHandshake returns a locked endpoint. The force call
+ // here just makes it so.
+ if err := h.complete(); err != nil { // +checklocksforce
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ ctx.cleanupFailedHandshake(h)
+ return
+ }
+ ctx.cleanupCompletedHandshake(h)
+ h.ep.startAcceptedLoop()
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
+ // Deliver the endpoint to the accept queue.
+ //
+ // Drop the lock before notifying to avoid deadlock in user-specified
+ // callbacks.
+ delivered := func() bool {
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ for {
+ // The listener is transitioning out of the Listen state; bail.
+ if e.acceptQueue.capacity == 0 {
+ return false
+ }
+ if e.acceptQueue.isFull() {
+ e.acceptCond.Wait()
+ continue
+ }
+
+ e.acceptQueue.endpoints.PushBack(h.ep)
+ return true
+ }
+ }()
+
+ if delivered {
+ e.waiterQueue.Notify(waiter.ReadableEvents)
+ } else {
+ h.ep.notifyProtocolGoroutine(notifyReset)
+ }
+ }()
+
+ return false, nil
+ }()
+ if err != nil {
+ return err
+ }
+ if !useSynCookies {
+ return nil
}
+
route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
if err != nil {
return err
@@ -627,18 +583,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
case s.flags.Contains(header.TCPFlagAck):
- if e.acceptQueueIsFull() {
- // Silently drop the ack as the application can't accept
- // the connection at this point. The ack will be
- // retransmitted by the sender anyway and we can
- // complete the connection at the time of retransmit if
- // the backlog has space.
- e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
- e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
- e.stack.Stats().DroppedPackets.Increment()
- return nil
- }
-
iss := s.ackNumber - 1
irs := s.sequenceNumber - 1
@@ -674,6 +618,24 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
// ACK was received from the sender.
return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
+
+ // Keep hold of acceptMu until the new endpoint is in the accept queue (or
+ // if there is an error), to guarantee that we will keep our spot in the
+ // queue even if another handshake from the syn queue completes.
+ e.acceptMu.Lock()
+ if e.acceptQueue.isFull() {
+ // Silently drop the ack as the application can't accept
+ // the connection at this point. The ack will be
+ // retransmitted by the sender anyway and we can
+ // complete the connection at the time of retransmit if
+ // the backlog has space.
+ e.acceptMu.Unlock()
+ e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
rcvdSynOptions := header.TCPSynOptions{
@@ -695,6 +657,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{})
if err != nil {
+ e.acceptMu.Unlock()
return err
}
@@ -706,6 +669,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
if !n.reserveTupleLocked() {
n.mu.Unlock()
+ e.acceptMu.Unlock()
n.Close()
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -723,6 +687,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n.boundBindToDevice,
); err != nil {
n.mu.Unlock()
+ e.acceptMu.Unlock()
n.Close()
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -755,20 +720,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
n.newSegmentWaker.Assert()
}
- // Do the delivery in a separate goroutine so
- // that we don't block the listen loop in case
- // the application is slow to accept or stops
- // accepting.
- //
- // NOTE: This won't result in an unbounded
- // number of goroutines as we do check before
- // entering here that there was at least some
- // space available in the backlog.
-
// Start the protocol goroutine.
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- go e.deliverAccepted(n, true /*withSynCookie*/)
+
+ // Deliver the endpoint to the accept queue.
+ e.acceptQueue.endpoints.PushBack(n)
+ e.acceptMu.Unlock()
+
+ e.waiterQueue.Notify(waiter.ReadableEvents)
return nil
default:
@@ -785,14 +745,8 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
ctx := newListenContext(e.stack, e.protocol, e, rcvWnd, v6Only, e.NetProto)
defer func() {
- // Mark endpoint as closed. This will prevent goroutines running
- // handleSynSegment() from attempting to queue new connections
- // to the endpoint.
e.setEndpointState(StateClose)
- // Close any endpoints in SYN-RCVD state.
- ctx.closeAllPendingEndpoints()
-
// Do cleanup if needed.
e.completeWorkerLocked()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 5d8e18484..80cd07218 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -30,6 +30,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// InitialRTO is the initial retransmission timeout.
+// https://github.com/torvalds/linux/blob/7c636d4d20f/include/net/tcp.h#L142
+const InitialRTO = time.Second
+
// maxSegmentsPerWake is the maximum number of segments to process in the main
// protocol goroutine per wake-up. Yielding [after this number of segments are
// processed] allows other events to be processed as well (e.g., timeouts,
@@ -532,7 +536,7 @@ func (h *handshake) complete() tcpip.Error {
defer s.Done()
// Initialize the resend timer.
- timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert)
+ timer, err := newBackoffTimer(h.ep.stack.Clock(), InitialRTO, MaxRTO, resendWaker.Assert)
if err != nil {
return err
}
@@ -578,6 +582,9 @@ func (h *handshake) complete() tcpip.Error {
if (n&notifyClose)|(n&notifyAbort) != 0 {
return &tcpip.ErrAborted{}
}
+ if n&notifyShutdown != 0 {
+ return &tcpip.ErrConnectionReset{}
+ }
if n&notifyDrain != 0 {
for !h.ep.segmentQueue.empty() {
s := h.ep.segmentQueue.dequeue()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index d2b8f298f..066ffe051 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,7 +15,6 @@
package tcp
import (
- "container/list"
"encoding/binary"
"fmt"
"io"
@@ -187,6 +186,8 @@ const (
// say TIME_WAIT.
notifyTickleWorker
notifyError
+ // notifyShutdown means that a connecting socket was shutdown.
+ notifyShutdown
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -203,6 +204,8 @@ type SACKInfo struct {
}
// ReceiveErrors collect segment receive errors within transport layer.
+//
+// +stateify savable
type ReceiveErrors struct {
tcpip.ReceiveErrors
@@ -232,6 +235,8 @@ type ReceiveErrors struct {
}
// SendErrors collect segment send errors within the transport layer.
+//
+// +stateify savable
type SendErrors struct {
tcpip.SendErrors
@@ -255,6 +260,8 @@ type SendErrors struct {
}
// Stats holds statistics about the endpoint.
+//
+// +stateify savable
type Stats struct {
// SegmentsReceived is the number of TCP segments received that
// the transport layer successfully parsed.
@@ -309,15 +316,6 @@ type rcvQueueInfo struct {
rcvQueue segmentList `state:"wait"`
}
-// +stateify savable
-type accepted struct {
- // NB: this could be an endpointList, but ilist only permits endpoints to
- // belong to one list at a time, and endpoints are already stored in the
- // dispatcher's list.
- endpoints list.List `state:".([]*endpoint)"`
- cap int
-}
-
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -333,7 +331,7 @@ type accepted struct {
// The following three mutexes can be acquired independent of e.mu but if
// acquired with e.mu then e.mu must be acquired first.
//
-// e.acceptMu -> protects accepted.
+// e.acceptMu -> Protects e.acceptQueue.
// e.rcvQueueMu -> Protects e.rcvQueue and associated fields.
// e.sndQueueMu -> Protects the e.sndQueue and associated fields.
// e.lastErrorMu -> Protects the lastError field.
@@ -497,10 +495,6 @@ type endpoint struct {
// and dropped when it is.
segmentQueue segmentQueue `state:"wait"`
- // synRcvdCount is the number of connections for this endpoint that are
- // in SYN-RCVD state; this is only accessed atomically.
- synRcvdCount int32
-
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
userMSS uint16
@@ -573,7 +567,8 @@ type endpoint struct {
// accepted is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
- accepted accepted
+ // +checklocks:acceptMu
+ acceptQueue acceptQueue
// The following are only used from the protocol goroutine, and
// therefore don't need locks to protect them.
@@ -606,8 +601,7 @@ type endpoint struct {
gso stack.GSO
- // TODO(b/142022063): Add ability to save and restore per endpoint stats.
- stats Stats `state:"nosave"`
+ stats Stats
// tcpLingerTimeout is the maximum amount of a time a socket
// a socket stays in TIME_WAIT state before being marked
@@ -819,10 +813,9 @@ func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProto
waiterQueue: waiterQueue,
state: uint32(StateInitial),
keepalive: keepalive{
- // Linux defaults.
- idle: 2 * time.Hour,
- interval: 75 * time.Second,
- count: 9,
+ idle: DefaultKeepaliveIdle,
+ interval: DefaultKeepaliveInterval,
+ count: DefaultKeepaliveCount,
},
uniqueID: s.UniqueID(),
txHash: s.Rand().Uint32(),
@@ -904,7 +897,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// Check if there's anything in the accepted queue.
if (mask & waiter.ReadableEvents) != 0 {
e.acceptMu.Lock()
- if e.accepted.endpoints.Len() != 0 {
+ if e.acceptQueue.endpoints.Len() != 0 {
result |= waiter.ReadableEvents
}
e.acceptMu.Unlock()
@@ -1087,20 +1080,20 @@ func (e *endpoint) closeNoShutdownLocked() {
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
e.acceptMu.Lock()
- acceptedCopy := e.accepted
- e.accepted = accepted{}
- e.acceptMu.Unlock()
-
- if acceptedCopy == (accepted{}) {
- return
+ // Close any endpoints in SYN-RCVD state.
+ for n := range e.acceptQueue.pendingEndpoints {
+ n.notifyProtocolGoroutine(notifyClose)
}
-
- e.acceptCond.Broadcast()
-
+ e.acceptQueue.pendingEndpoints = nil
// Reset all connections that are waiting to be accepted.
- for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() {
+ for n := e.acceptQueue.endpoints.Front(); n != nil; n = n.Next() {
n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset)
}
+ e.acceptQueue.endpoints.Init()
+ e.acceptMu.Unlock()
+
+ e.acceptCond.Broadcast()
+
// Wait for reset of all endpoints that are still waiting to be delivered to
// the now closed accepted.
e.pendingAccepted.Wait()
@@ -2060,7 +2053,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
case *tcpip.OriginalDestinationOption:
e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto)
+ addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber)
e.UnlockUser()
if err != nil {
return err
@@ -2380,6 +2373,18 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.LockUser()
defer e.UnlockUser()
+
+ if e.EndpointState().connecting() {
+ // When calling shutdown(2) on a connecting socket, the endpoint must
+ // enter the error state. But this logic cannot belong to the shutdownLocked
+ // method because that method is called during a close(2) (and closing a
+ // connecting socket is not an error).
+ e.resetConnectionLocked(&tcpip.ErrConnectionReset{})
+ e.notifyProtocolGoroutine(notifyShutdown)
+ e.waiterQueue.Notify(waiter.WritableEvents | waiter.EventHUp | waiter.EventErr)
+ return nil
+ }
+
return e.shutdownLocked(flags)
}
@@ -2480,22 +2485,23 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
if e.EndpointState() == StateListen && !e.closed {
e.acceptMu.Lock()
defer e.acceptMu.Unlock()
- if e.accepted == (accepted{}) {
- // listen is called after shutdown.
- e.accepted.cap = backlog
- e.shutdownFlags = 0
- e.rcvQueueInfo.rcvQueueMu.Lock()
- e.rcvQueueInfo.RcvClosed = false
- e.rcvQueueInfo.rcvQueueMu.Unlock()
- } else {
- // Adjust the size of the backlog iff we can fit
- // existing pending connections into the new one.
- if e.accepted.endpoints.Len() > backlog {
- return &tcpip.ErrInvalidEndpointState{}
- }
- e.accepted.cap = backlog
+
+ // Adjust the size of the backlog iff we can fit
+ // existing pending connections into the new one.
+ if e.acceptQueue.endpoints.Len() > backlog {
+ return &tcpip.ErrInvalidEndpointState{}
+ }
+ e.acceptQueue.capacity = backlog
+
+ if e.acceptQueue.pendingEndpoints == nil {
+ e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{})
}
+ e.shutdownFlags = 0
+ e.rcvQueueInfo.rcvQueueMu.Lock()
+ e.rcvQueueInfo.RcvClosed = false
+ e.rcvQueueInfo.rcvQueueMu.Unlock()
+
// Notify any blocked goroutines that they can attempt to
// deliver endpoints again.
e.acceptCond.Broadcast()
@@ -2530,8 +2536,11 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
// may be pre-populated with some previously accepted (but not Accepted)
// endpoints.
e.acceptMu.Lock()
- if e.accepted == (accepted{}) {
- e.accepted.cap = backlog
+ if e.acceptQueue.pendingEndpoints == nil {
+ e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{})
+ }
+ if e.acceptQueue.capacity == 0 {
+ e.acceptQueue.capacity = backlog
}
e.acceptMu.Unlock()
@@ -2571,8 +2580,8 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
// Get the new accepted endpoint.
var n *endpoint
e.acceptMu.Lock()
- if element := e.accepted.endpoints.Front(); element != nil {
- n = e.accepted.endpoints.Remove(element).(*endpoint)
+ if element := e.acceptQueue.endpoints.Front(); element != nil {
+ n = e.acceptQueue.endpoints.Remove(element).(*endpoint)
}
e.acceptMu.Unlock()
if n == nil {
@@ -2989,6 +2998,8 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState {
}
s.Sender.RACKState = e.snd.rc.TCPRACKState
+ s.Sender.RetransmitTS = e.snd.retransmitTS
+ s.Sender.SpuriousRecovery = e.snd.spuriousRecovery
return s
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index f2e8b3840..94072a115 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -100,7 +100,7 @@ func (e *endpoint) beforeSave() {
}
// saveEndpoints is invoked by stateify.
-func (a *accepted) saveEndpoints() []*endpoint {
+func (a *acceptQueue) saveEndpoints() []*endpoint {
acceptedEndpoints := make([]*endpoint, a.endpoints.Len())
for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() {
acceptedEndpoints[i] = e.Value.(*endpoint)
@@ -109,7 +109,7 @@ func (a *accepted) saveEndpoints() []*endpoint {
}
// loadEndpoints is invoked by stateify.
-func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) {
+func (a *acceptQueue) loadEndpoints(acceptedEndpoints []*endpoint) {
for _, ep := range acceptedEndpoints {
a.endpoints.PushBack(ep)
}
@@ -251,7 +251,9 @@ func (e *endpoint) Resume(s *stack.Stack) {
go func() {
connectedLoading.Wait()
bind()
- backlog := e.accepted.cap
+ e.acceptMu.Lock()
+ backlog := e.acceptQueue.capacity
+ e.acceptMu.Unlock()
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index e4410ad93..f122ea009 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -66,6 +66,18 @@ const (
// DefaultSynRetries is the default value for the number of SYN retransmits
// before a connect is aborted.
DefaultSynRetries = 6
+
+ // DefaultKeepaliveIdle is the idle time for a connection before keep-alive
+ // probes are sent.
+ DefaultKeepaliveIdle = 2 * time.Hour
+
+ // DefaultKeepaliveInterval is the time between two successive keep-alive
+ // probes.
+ DefaultKeepaliveInterval = 75 * time.Second
+
+ // DefaultKeepaliveCount is the number of keep-alive probes that are sent
+ // before declaring the connection dead.
+ DefaultKeepaliveCount = 9
)
const (
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
index 8a026ec46..e47a07030 100644
--- a/pkg/tcpip/transport/tcp/rcv_test.go
+++ b/pkg/tcpip/transport/tcp/rcv_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package rcv_test
+package tcp_test
import (
"testing"
diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go
index 2e6ea06f5..2d5fdda19 100644
--- a/pkg/tcpip/transport/tcp/segment_test.go
+++ b/pkg/tcpip/transport/tcp/segment_test.go
@@ -34,7 +34,7 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW
DataSize: seg.data.Size(),
SegMemSize: seg.segMemSize(),
}
- if diff := cmp.Diff(got, want); diff != "" {
+ if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("%s differs (-want +got):\n%s", name, diff)
}
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 2fabf1594..4377f07a0 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -144,6 +144,15 @@ type sender struct {
// probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm.
probeTimer timer `state:"nosave"`
probeWaker sleep.Waker `state:"nosave"`
+
+ // spuriousRecovery indicates whether the sender entered recovery
+ // spuriously as described in RFC3522 Section 3.2.
+ spuriousRecovery bool
+
+ // retransmitTS is the timestamp at which the sender sends retransmitted
+ // segment after entering an RTO for the first time as described in
+ // RFC3522 Section 3.2.
+ retransmitTS uint32
}
// rtt is a synchronization wrapper used to appease stateify. See the comment
@@ -425,6 +434,13 @@ func (s *sender) retransmitTimerExpired() bool {
return true
}
+ // Initialize the variables used to detect spurious recovery after
+ // entering RTO.
+ //
+ // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1.
+ s.spuriousRecovery = false
+ s.retransmitTS = 0
+
// TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases
// when writeList is empty. Remove this once we have a proper fix for this
// issue.
@@ -495,6 +511,10 @@ func (s *sender) retransmitTimerExpired() bool {
s.leaveRecovery()
}
+ // Record retransmitTS if the sender is not in recovery as per:
+ // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ s.recordRetransmitTS()
+
s.state = tcpip.RTORecovery
s.cc.HandleRTOExpired()
@@ -958,6 +978,13 @@ func (s *sender) sendData() {
}
func (s *sender) enterRecovery() {
+ // Initialize the variables used to detect spurious recovery after
+ // entering recovery.
+ //
+ // See: https://www.rfc-editor.org/rfc/rfc3522.html#section-3.2 Step 1.
+ s.spuriousRecovery = false
+ s.retransmitTS = 0
+
s.FastRecovery.Active = true
// Save state to reflect we're now in fast recovery.
//
@@ -972,6 +999,11 @@ func (s *sender) enterRecovery() {
s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding
s.FastRecovery.HighRxt = s.SndUna
s.FastRecovery.RescueRxt = s.SndUna
+
+ // Record retransmitTS if the sender is not in recovery as per:
+ // https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ s.recordRetransmitTS()
+
if s.ep.SACKPermitted {
s.state = tcpip.SACKRecovery
s.ep.stack.Stats().TCP.SACKRecovery.Increment()
@@ -1147,13 +1179,15 @@ func (s *sender) isDupAck(seg *segment) bool {
// Iterate the writeList and update RACK for each segment which is newly acked
// either cumulatively or selectively. Loop through the segments which are
// sacked, and update the RACK related variables and check for reordering.
+// Returns true when the DSACK block has been detected in the received ACK.
//
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// steps 2 and 3.
-func (s *sender) walkSACK(rcvdSeg *segment) {
+func (s *sender) walkSACK(rcvdSeg *segment) bool {
s.rc.setDSACKSeen(false)
// Look for DSACK block.
+ hasDSACK := false
idx := 0
n := len(rcvdSeg.parsedOptions.SACKBlocks)
if checkDSACK(rcvdSeg) {
@@ -1167,10 +1201,11 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
s.rc.setDSACKSeen(true)
idx = 1
n--
+ hasDSACK = true
}
if n == 0 {
- return
+ return hasDSACK
}
// Sort the SACK blocks. The first block is the most recent unacked
@@ -1193,6 +1228,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
seg = seg.Next()
}
}
+ return hasDSACK
}
// checkDSACK checks if a DSACK is reported.
@@ -1239,6 +1275,85 @@ func checkDSACK(rcvdSeg *segment) bool {
return false
}
+func (s *sender) recordRetransmitTS() {
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2
+ //
+ // The Eifel detection algorithm is used, only upon initiation of loss
+ // recovery, i.e., when either the timeout-based retransmit or the fast
+ // retransmit is sent. The Eifel detection algorithm MUST NOT be
+ // reinitiated after loss recovery has already started. In particular,
+ // it must not be reinitiated upon subsequent timeouts for the same
+ // segment, and not upon retransmitting segments other than the oldest
+ // outstanding segment, e.g., during selective loss recovery.
+ if s.inRecovery() {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 2
+ //
+ // Set a "RetransmitTS" variable to the value of the Timestamp Value
+ // field of the Timestamps option included in the retransmit sent when
+ // loss recovery is initiated. A TCP sender must ensure that
+ // RetransmitTS does not get overwritten as loss recovery progresses,
+ // e.g., in case of a second timeout and subsequent second retransmit of
+ // the same octet.
+ s.retransmitTS = s.ep.tsValNow()
+}
+
+func (s *sender) detectSpuriousRecovery(hasDSACK bool, tsEchoReply uint32) {
+ // Return if the sender has already detected spurious recovery.
+ if s.spuriousRecovery {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 4
+ //
+ // If the value of the Timestamp Echo Reply field of the acceptable ACK's
+ // Timestamps option is smaller than the value of RetransmitTS, then
+ // proceed to next step, else return.
+ if tsEchoReply >= s.retransmitTS {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5
+ //
+ // If the acceptable ACK carries a DSACK option [RFC2883], then return.
+ if hasDSACK {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 5
+ //
+ // If during the lifetime of the TCP connection the TCP sender has
+ // previously received an ACK with a DSACK option, or the acceptable ACK
+ // does not acknowledge all outstanding data, then proceed to next step,
+ // else return.
+ numDSACK := s.ep.stack.Stats().TCP.SegmentsAckedWithDSACK.Value()
+ if numDSACK == 0 && s.SndUna == s.SndNxt {
+ return
+ }
+
+ // See: https://datatracker.ietf.org/doc/html/rfc3522#section-3.2 Step 6
+ //
+ // If the loss recovery has been initiated with a timeout-based
+ // retransmit, then set
+ // SpuriousRecovery <- SPUR_TO (equal 1),
+ // else set
+ // SpuriousRecovery <- dupacks+1
+ // Set the spurious recovery variable to true as we do not differentiate
+ // between fast, SACK or RTO recovery.
+ s.spuriousRecovery = true
+ s.ep.stack.Stats().TCP.SpuriousRecovery.Increment()
+}
+
+// Check if the sender is in RTORecovery, FastRecovery or SACKRecovery state.
+func (s *sender) inRecovery() bool {
+ if s.state == tcpip.RTORecovery || s.state == tcpip.FastRecovery || s.state == tcpip.SACKRecovery {
+ return true
+ }
+ return false
+}
+
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
@@ -1254,6 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// Insert SACKBlock information into our scoreboard.
+ hasDSACK := false
if s.ep.SACKPermitted {
for _, sb := range rcvdSeg.parsedOptions.SACKBlocks {
// Only insert the SACK block if the following holds
@@ -1288,7 +1404,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// RACK.fack, then the corresponding packet has been
// reordered and RACK.reord is set to TRUE.
if s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 {
- s.walkSACK(rcvdSeg)
+ hasDSACK = s.walkSACK(rcvdSeg)
}
s.SetPipe()
}
@@ -1418,6 +1534,11 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Clear SACK information for all acked data.
s.ep.scoreboard.Delete(s.SndUna)
+ // Detect if the sender entered recovery spuriously.
+ if s.inRecovery() {
+ s.detectSpuriousRecovery(hasDSACK, rcvdSeg.parsedOptions.TSEcr)
+ }
+
// If we are not in fast recovery then update the congestion
// window based on the number of acknowledged packets.
if !s.FastRecovery.Active {
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index c35db7c95..0d36d0dd0 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -1059,16 +1059,17 @@ func TestRACKWithWindowFull(t *testing.T) {
for i := 0; i < numPkts; i++ {
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
- if i == 0 {
- // Send ACK for the first packet to establish RTT.
- c.SendAck(seq, maxPayload)
- }
}
- // SACK for #10 packet.
- start := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload))
+ // Expect retransmission of last packet due to TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, (numPkts-1)*maxPayload, maxPayload, tsOptionSize)
+
+ // SACK for first and last packet.
+ start := c.IRS.Add(seqnum.Size(maxPayload))
end := start.Add(seqnum.Size(maxPayload))
- c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{start, end}})
+ dsackStart := c.IRS.Add(seqnum.Size(1 + (numPkts-1)*maxPayload))
+ dsackEnd := dsackStart.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, 2*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}})
var info tcpip.TCPInfoOption
if err := c.EP.GetSockOpt(&info); err != nil {
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 6255355bb..896249d2d 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -23,6 +23,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -702,3 +703,257 @@ func TestRecoveryEntry(t *testing.T) {
t.Error(err)
}
}
+
+func verifySpuriousRecoveryMetric(t *testing.T, c *context.Context, numSpuriousRecovery uint64) {
+ t.Helper()
+
+ metricPollFn := func() error {
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ {tcpStats.SpuriousRecovery, "stats.TCP.SpuriousRecovery", numSpuriousRecovery},
+ }
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+ }
+ }
+ return nil
+ }
+
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
+}
+
+func checkReceivedPacket(t *testing.T, c *context.Context, tcpHdr header.TCP, bytesRead uint32, b, data []byte) {
+ payloadLen := uint32(len(tcpHdr.Payload()))
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPSeqNum(uint32(c.IRS)+1+bytesRead),
+ checker.TCPAckNum(context.TestInitialSequenceNumber+1),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
+ ),
+ )
+ pdata := data[bytesRead : bytesRead+payloadLen]
+ if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) {
+ t.Fatalf("got data = %v, want = %v", p, pdata)
+ }
+}
+
+func buildTSOptionFromHeader(tcpHdr header.TCP) []byte {
+ parsedOpts := tcpHdr.ParsedOptions()
+ tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+ header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:])
+ return tsOpt[:]
+}
+
+func TestDetectSpuriousRecoveryWithRTO(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan struct{})
+ c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
+ if s.Sender.RetransmitTS == 0 {
+ t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0")
+ }
+ if !s.Sender.SpuriousRecovery {
+ t.Fatalf("Spurious recovery was not detected")
+ }
+ close(probeDone)
+ })
+
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := make([]byte, numPackets*maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+ // Write the data.
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ var options []byte
+ var bytesRead uint32
+ for i := 0; i < numPackets; i++ {
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data)
+
+ // Get options only for the first packet. This will be sent with
+ // the ACK to indicate the acknowledgement is for the original
+ // packet.
+ if i == 0 && c.TimeStampEnabled {
+ options = buildTSOptionFromHeader(tcpHdr)
+ }
+ bytesRead += uint32(len(tcpHdr.Payload()))
+ }
+
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ // Expect #5 segment with TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ // Expect #1 segment because of RTO.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ info := tcpip.TCPInfoOption{}
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+ }
+
+ if info.CcState != tcpip.RTORecovery {
+ t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.RTORecovery)
+ }
+
+ // Acknowledge the data.
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seq,
+ AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)),
+ RcvWnd: rcvWnd,
+ TCPOpts: options,
+ })
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ <-probeDone
+
+ verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */)
+}
+
+func TestSACKDetectSpuriousRecoveryWithDupACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ numAck := 0
+ probeDone := make(chan struct{})
+ c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
+ if numAck < 3 {
+ numAck++
+ return
+ }
+
+ if s.Sender.RetransmitTS == 0 {
+ t.Fatalf("RetransmitTS did not get updated, got: 0 want > 0")
+ }
+ if !s.Sender.SpuriousRecovery {
+ t.Fatalf("Spurious recovery was not detected")
+ }
+ close(probeDone)
+ })
+
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := make([]byte, numPackets*maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+ // Write the data.
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ var options []byte
+ var bytesRead uint32
+ for i := 0; i < numPackets; i++ {
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ checkReceivedPacket(t, c, tcpHdr, bytesRead, b, data)
+
+ // Get options only for the first packet. This will be sent with
+ // the ACK to indicate the acknowledgement is for the original
+ // packet.
+ if i == 0 && c.TimeStampEnabled {
+ options = buildTSOptionFromHeader(tcpHdr)
+ }
+ bytesRead += uint32(len(tcpHdr.Payload()))
+ }
+
+ // Receive the retransmitted packet after TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ // Send ACK for #3 and #4 segments to avoid entering TLP.
+ start := c.IRS.Add(3*maxPayload + 1)
+ end := start.Add(2 * maxPayload)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+
+ c.SendAck(seq, 0 /* bytesReceived */)
+ c.SendAck(seq, 0 /* bytesReceived */)
+
+ // Receive the retransmitted packet after three duplicate ACKs.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ info := tcpip.TCPInfoOption{}
+ if err := c.EP.GetSockOpt(&info); err != nil {
+ t.Fatalf("c.EP.GetSockOpt(&%T) = %s", info, err)
+ }
+
+ if info.CcState != tcpip.SACKRecovery {
+ t.Fatalf("Loss recovery did not happen, got: %v want: %v", info.CcState, tcpip.SACKRecovery)
+ }
+
+ // Acknowledge the data.
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seq,
+ AckNum: c.IRS.Add(1 + seqnum.Size(maxPayload)),
+ RcvWnd: rcvWnd,
+ TCPOpts: options,
+ })
+
+ // Wait for the probe function to finish processing the
+ // ACK before the test completes.
+ <-probeDone
+
+ verifySpuriousRecoveryMetric(t, c, 1 /* numSpuriousRecovery */)
+}
+
+func TestNoSpuriousRecoveryWithDSACK(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ numPackets := 5
+ data := sendAndReceiveWithSACK(t, c, numPackets, true /* enableRACK */)
+
+ // Receive the retransmitted packet after TLP.
+ c.ReceiveAndCheckPacketWithOptions(data, 4*maxPayload, maxPayload, tsOptionSize)
+
+ // Send ACK for #3 and #4 segments to avoid entering TLP.
+ start := c.IRS.Add(3*maxPayload + 1)
+ end := start.Add(2 * maxPayload)
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+
+ c.SendAck(seq, 0 /* bytesReceived */)
+ c.SendAck(seq, 0 /* bytesReceived */)
+
+ // Receive the retransmitted packet after three duplicate ACKs.
+ c.ReceiveAndCheckPacketWithOptions(data, 0, maxPayload, tsOptionSize)
+
+ // Acknowledge the data with DSACK for #1 segment.
+ start = c.IRS.Add(maxPayload + 1)
+ end = start.Add(2 * maxPayload)
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 6*maxPayload, []header.SACKBlock{{start, end}})
+
+ verifySpuriousRecoveryMetric(t, c, 0 /* numSpuriousRecovery */)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index bc8708a5b..6f1ee3816 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1382,8 +1382,12 @@ func TestListenerReadinessOnEvent(t *testing.T) {
if err := s.CreateNIC(id, ep); err != nil {
t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err)
}
- if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil {
- t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(context.StackAddr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", id, protocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: id},
@@ -1652,6 +1656,71 @@ func TestConnectBindToDevice(t *testing.T) {
}
}
+func TestShutdownConnectingSocket(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ shutdownMode tcpip.ShutdownFlags
+ }{
+ {"ShutdownRead", tcpip.ShutdownRead},
+ {"ShutdownWrite", tcpip.ShutdownWrite},
+ {"ShutdownReadWrite", tcpip.ShutdownRead | tcpip.ShutdownWrite},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create an endpoint, don't handshake because we want to interfere with
+ // the handshake process.
+ c.Create(-1)
+
+ waitEntry, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ // Start connection attempt.
+ addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, c.EP.Connect(addr)); d != "" {
+ t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
+ }
+
+ // Check the SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+
+ if err := c.EP.Shutdown(test.shutdownMode); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ // The endpoint internal state is updated immediately.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+
+ select {
+ case <-ch:
+ default:
+ t.Fatal("endpoint was not notified")
+ }
+
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, &tcpip.ErrConnectionReset{})
+
+ // If the endpoint is not properly shutdown, it'll re-attempt to connect
+ // by sending another ACK packet.
+ c.CheckNoPacketTimeout("got an unexpected packet", tcp.InitialRTO+(500*time.Millisecond))
+ })
+ }
+}
+
func TestSynSent(t *testing.T) {
for _, test := range []struct {
name string
@@ -1675,7 +1744,7 @@ func TestSynSent(t *testing.T) {
addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
err := c.EP.Connect(addr)
- if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" {
+ if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" {
t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d)
}
@@ -1991,7 +2060,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
)
// Cause a FIN to be generated.
- c.EP.Shutdown(tcpip.ShutdownWrite)
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
// Make sure we get the FIN but DON't ACK IT.
checker.IPv4(t, c.GetPacket(),
@@ -2007,7 +2078,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
// Cause a RST to be generated by closing the read end now since we have
// unread data.
- c.EP.Shutdown(tcpip.ShutdownRead)
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
// Make sure we get the RST
checker.IPv4(t, c.GetPacket(),
@@ -2145,12 +2218,15 @@ func TestSmallReceiveBufferReadiness(t *testing.T) {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err)
}
- addr := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x7f\x00\x00\x01"),
- PrefixLen: 8,
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x7f\x00\x00\x01"),
+ PrefixLen: 8,
+ },
}
- if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddressWithPrefix(_, _, %s) failed: %s", addr, err)
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}) failed: %s", nicID, protocolAddr, err)
}
{
@@ -4954,13 +5030,17 @@ func makeStack() (*stack.Stack, tcpip.Error) {
}
for _, ct := range []struct {
- number tcpip.NetworkProtocolNumber
- address tcpip.Address
+ number tcpip.NetworkProtocolNumber
+ addrWithPrefix tcpip.AddressWithPrefix
}{
- {ipv4.ProtocolNumber, context.StackAddr},
- {ipv6.ProtocolNumber, context.StackV6Addr},
+ {ipv4.ProtocolNumber, context.StackAddrWithPrefix},
+ {ipv6.ProtocolNumber, context.StackV6AddrWithPrefix},
} {
- if err := s.AddAddress(1, ct.number, ct.address); err != nil {
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ct.number,
+ AddressWithPrefix: ct.addrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil {
return nil, err
}
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 6e55a7a32..88bb99354 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -243,8 +243,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: StackAddrWithPrefix,
}
- if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v4ProtocolAddr, err)
}
routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv4EmptySubnet,
@@ -257,8 +257,8 @@ func NewWithOpts(t *testing.T, opts Options) *Context {
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: StackV6AddrWithPrefix,
}
- if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", v6ProtocolAddr, err)
}
routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv6EmptySubnet,
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 5cc7a2886..d2c0963b0 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -63,5 +63,6 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
+ "@org_golang_x_time//rate:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 4255457f9..077a2325a 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -60,9 +60,8 @@ type endpoint struct {
waiterQueue *waiter.Queue
uniqueID uint64
net network.Endpoint
- // TODO(b/142022063): Add ability to save and restore per endpoint stats.
- stats tcpip.TransportEndpointStats `state:"nosave"`
- ops tcpip.SocketOptions
+ stats tcpip.TransportEndpointStats
+ ops tcpip.SocketOptions
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -234,7 +233,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
- Timestamp: p.receivedAt.UnixNano(),
+ Timestamp: p.receivedAt,
}
switch p.netProto {
@@ -243,19 +242,29 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
cm.HasTOS = true
cm.TOS = p.tos
}
+
+ if e.ops.GetReceivePacketInfo() {
+ cm.HasIPPacketInfo = true
+ cm.PacketInfo = p.packetInfo
+ }
case header.IPv6ProtocolNumber:
if e.ops.GetReceiveTClass() {
cm.HasTClass = true
// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
cm.TClass = uint32(p.tos)
}
+
+ if e.ops.GetIPv6ReceivePacketInfo() {
+ cm.HasIPv6PacketInfo = true
+ cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{
+ NIC: p.packetInfo.NIC,
+ Addr: p.packetInfo.DestinationAddr,
+ }
+ }
default:
panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto))
}
- if e.ops.GetReceivePacketInfo() {
- cm.HasIPPacketInfo = true
- cm.PacketInfo = p.packetInfo
- }
+
if e.ops.GetReceiveOriginalDstAddress() {
cm.HasOriginalDstAddress = true
cm.OriginalDstAddress = p.destinationAddress
@@ -283,7 +292,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
-// +checklocks:e.mu
+// +checklocksread:e.mu
func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.net.State() {
case transport.DatagramEndpointStateInitial:
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 554ce1de4..b3199489c 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -22,6 +22,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -313,6 +314,9 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo
Clock: &faketime.NullClock{},
}
s := stack.New(options)
+ // Disable ICMP rate limiter because we're using Null clock, which never advances time and thus
+ // never allows ICMP messages.
+ s.SetICMPLimit(rate.Inf)
ep := channel.New(256, mtu, "")
wep := stack.LinkEndpoint(ep)
@@ -323,12 +327,20 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err)
+ protocolAddrV4 := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(stackAddr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress((%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, stackV6Addr, err)
+ protocolAddrV6 := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.Address(stackV6Addr).WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err)
}
s.SetRouteTable([]tcpip.Route{
@@ -1357,64 +1369,70 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
func TestReadIPPacketInfo(t *testing.T) {
tests := []struct {
- name string
- proto tcpip.NetworkProtocolNumber
- flow testFlow
- expectedLocalAddr tcpip.Address
- expectedDestAddr tcpip.Address
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ checker func(tcpip.NICID) checker.ControlMessagesChecker
}{
{
- name: "IPv4 unicast",
- proto: header.IPv4ProtocolNumber,
- flow: unicastV4,
- expectedLocalAddr: stackAddr,
- expectedDestAddr: stackAddr,
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ LocalAddr: stackAddr,
+ DestinationAddr: stackAddr,
+ })
+ },
},
{
name: "IPv4 multicast",
proto: header.IPv4ProtocolNumber,
flow: multicastV4,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: multicastAddr,
- expectedDestAddr: multicastAddr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ // TODO(gvisor.dev/issue/3556): Check for a unicast address.
+ LocalAddr: multicastAddr,
+ DestinationAddr: multicastAddr,
+ })
+ },
},
{
name: "IPv4 broadcast",
proto: header.IPv4ProtocolNumber,
flow: broadcast,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: broadcastAddr,
- expectedDestAddr: broadcastAddr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: id,
+ // TODO(gvisor.dev/issue/3556): Check for a unicast address.
+ LocalAddr: broadcastAddr,
+ DestinationAddr: broadcastAddr,
+ })
+ },
},
{
- name: "IPv6 unicast",
- proto: header.IPv6ProtocolNumber,
- flow: unicastV6,
- expectedLocalAddr: stackV6Addr,
- expectedDestAddr: stackV6Addr,
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
+ NIC: id,
+ Addr: stackV6Addr,
+ })
+ },
},
{
name: "IPv6 multicast",
proto: header.IPv6ProtocolNumber,
flow: multicastV6,
- // This should actually be a unicast address assigned to the interface.
- //
- // TODO(gvisor.dev/issue/3556): This check is validating incorrect
- // behaviour. We still include the test so that once the bug is
- // resolved, this test will start to fail and the individual tasked
- // with fixing this bug knows to also fix this test :).
- expectedLocalAddr: multicastV6Addr,
- expectedDestAddr: multicastV6Addr,
+ checker: func(id tcpip.NICID) checker.ControlMessagesChecker {
+ return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
+ NIC: id,
+ Addr: multicastV6Addr,
+ })
+ },
},
}
@@ -1437,13 +1455,16 @@ func TestReadIPPacketInfo(t *testing.T) {
}
}
- c.ep.SocketOptions().SetReceivePacketInfo(true)
+ switch f := test.flow.netProto(); f {
+ case header.IPv4ProtocolNumber:
+ c.ep.SocketOptions().SetReceivePacketInfo(true)
+ case header.IPv6ProtocolNumber:
+ c.ep.SocketOptions().SetIPv6ReceivePacketInfo(true)
+ default:
+ t.Fatalf("unhandled protocol number = %d", f)
+ }
- testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
- NIC: 1,
- LocalAddr: test.expectedLocalAddr,
- DestinationAddr: test.expectedDestAddr,
- }))
+ testRead(c, test.flow, test.checker(c.nicID))
if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
@@ -2504,8 +2525,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
if err := s.CreateNIC(nicID1, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
}
- if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err)
}
s.SetRouteTable(test.routes)
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD
index 234125c38..8902be2d3 100644
--- a/pkg/unet/BUILD
+++ b/pkg/unet/BUILD
@@ -10,6 +10,7 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/eventfd",
"//pkg/sync",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/unet/unet.go b/pkg/unet/unet.go
index 40fa72925..0dc0c37bd 100644
--- a/pkg/unet/unet.go
+++ b/pkg/unet/unet.go
@@ -23,6 +23,7 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/eventfd"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -55,15 +56,6 @@ func socket(packet bool) (int, error) {
return fd, nil
}
-// eventFD returns a new event FD with initial value 0.
-func eventFD() (int, error) {
- f, _, e := unix.Syscall(unix.SYS_EVENTFD2, 0, 0, 0)
- if e != 0 {
- return -1, e
- }
- return int(f), nil
-}
-
// Socket is a connected unix domain socket.
type Socket struct {
// gate protects use of fd.
@@ -78,7 +70,7 @@ type Socket struct {
// efd is an event FD that is signaled when the socket is closing.
//
// efd is immutable and remains valid until Close/Release.
- efd int
+ efd eventfd.Eventfd
// race is an atomic variable used to avoid triggering the race
// detector. See comment in SocketPair below.
@@ -95,7 +87,7 @@ func NewSocket(fd int) (*Socket, error) {
return nil, err
}
- efd, err := eventFD()
+ efd, err := eventfd.Create()
if err != nil {
return nil, err
}
@@ -110,16 +102,14 @@ func NewSocket(fd int) (*Socket, error) {
// closing the event FD.
func (s *Socket) finish() error {
// Signal any blocked or future polls.
- //
- // N.B. eventfd writes must be 8 bytes.
- if _, err := unix.Write(s.efd, []byte{1, 0, 0, 0, 0, 0, 0, 0}); err != nil {
+ if err := s.efd.Notify(); err != nil {
return err
}
// Close the gate, blocking until all FD users leave.
s.gate.Close()
- return unix.Close(s.efd)
+ return s.efd.Close()
}
// Close closes the socket.
diff --git a/pkg/unet/unet_unsafe.go b/pkg/unet/unet_unsafe.go
index f0bf93ddd..ea281fec3 100644
--- a/pkg/unet/unet_unsafe.go
+++ b/pkg/unet/unet_unsafe.go
@@ -43,7 +43,7 @@ func (s *Socket) wait(write bool) error {
},
{
// The eventfd, signaled when we are closing.
- Fd: int32(s.efd),
+ Fd: int32(s.efd.FD()),
Events: unix.POLLIN,
},
}
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 7028a465f..19328f89c 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -15,6 +15,7 @@ go_library(
"limits.go",
"loader.go",
"network.go",
+ "profile.go",
"strace.go",
"vfs.go",
],
@@ -80,7 +81,6 @@ go_library(
"//pkg/sentry/loader",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
- "//pkg/sentry/sighandling",
"//pkg/sentry/socket/hostinet",
"//pkg/sentry/socket/netfilter",
"//pkg/sentry/socket/netlink",
@@ -96,6 +96,7 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
"//pkg/sentry/watchdog",
+ "//pkg/sighandling",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/link/ethernet",
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index 703f34827..4a9d30c5f 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -304,6 +304,22 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.EqualTo(unix.SPLICE_F_NONBLOCK), /* flags */
},
},
+ unix.SYS_TIMER_CREATE: []seccomp.Rule{
+ {
+ seccomp.EqualTo(unix.CLOCK_THREAD_CPUTIME_ID), /* which */
+ seccomp.MatchAny{}, /* sevp */
+ seccomp.MatchAny{}, /* timerid */
+ },
+ },
+ unix.SYS_TIMER_DELETE: []seccomp.Rule{},
+ unix.SYS_TIMER_SETTIME: []seccomp.Rule{
+ {
+ seccomp.MatchAny{}, /* timerid */
+ seccomp.EqualTo(0), /* flags */
+ seccomp.MatchAny{}, /* new_value */
+ seccomp.EqualTo(0), /* old_value */
+ },
+ },
unix.SYS_TGKILL: []seccomp.Rule{
{
seccomp.EqualTo(uint64(os.Getpid())),
@@ -630,7 +646,7 @@ func hostInetFilters() seccomp.SyscallRules {
func controlServerFilters(fd int) seccomp.SyscallRules {
return seccomp.SyscallRules{
- unix.SYS_ACCEPT: []seccomp.Rule{
+ unix.SYS_ACCEPT4: []seccomp.Rule{
{
seccomp.EqualTo(fd),
},
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index 40cf2a3df..ceeccf483 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -188,8 +188,8 @@ func compileMounts(spec *specs.Spec, conf *config.Config, vfs2Enabled bool) []sp
return mounts
}
-// p9MountData creates a slice of p9 mount data.
-func p9MountData(fd int, fa config.FileAccessType, vfs2 bool) []string {
+// goferMountData creates a slice of gofer mount data.
+func goferMountData(fd int, fa config.FileAccessType, attachPath string, vfs2 bool, lisafs bool) []string {
opts := []string{
"trans=fd",
"rfdno=" + strconv.Itoa(fd),
@@ -203,6 +203,10 @@ func p9MountData(fd int, fa config.FileAccessType, vfs2 bool) []string {
if fa == config.FileAccessShared {
opts = append(opts, "cache=remote_revalidating")
}
+ if vfs2 && lisafs {
+ opts = append(opts, "lisafs=true")
+ opts = append(opts, "aname="+attachPath)
+ }
return opts
}
@@ -761,7 +765,7 @@ func (c *containerMounter) createRootMount(ctx context.Context, conf *config.Con
fd := c.fds.remove()
log.Infof("Mounting root over 9P, ioFD: %d", fd)
p9FS := mustFindFilesystem("9p")
- opts := p9MountData(fd, conf.FileAccess, false /* vfs2 */)
+ opts := goferMountData(fd, conf.FileAccess, "/", false /* vfs2 */, false /* lisafs */)
// We can't check for overlayfs here because sandbox is chroot'ed and gofer
// can only send mount options for specs.Mounts (specs.Root is missing
@@ -822,7 +826,7 @@ func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m *specs.
case bind:
fd := c.fds.remove()
fsName = gofervfs2.Name
- opts = p9MountData(fd, c.getMountAccessType(conf, m), conf.VFS2)
+ opts = goferMountData(fd, c.getMountAccessType(conf, m), m.Destination, conf.VFS2, conf.Lisafs)
// If configured, add overlay to all writable mounts.
useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
case cgroupfs.Name:
@@ -980,7 +984,7 @@ func (c *containerMounter) createRestoreEnvironment(conf *config.Config) (*fs.Re
// Add root mount.
fd := c.fds.remove()
- opts := p9MountData(fd, conf.FileAccess, false /* vfs2 */)
+ opts := goferMountData(fd, conf.FileAccess, "/", conf.VFS2, false /* lisafs */)
mf := fs.MountSourceFlags{}
if c.root.Readonly || conf.Overlay {
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 1dd0048ac..dbb02fdf4 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -49,13 +49,13 @@ import (
"gvisor.dev/gvisor/pkg/sentry/loader"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/sighandling"
"gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/sentry/syscalls/linux/vfs2"
"gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
+ "gvisor.dev/gvisor/pkg/sighandling"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
@@ -120,6 +120,10 @@ type Loader struct {
// container. It should be called when a sandbox is destroyed.
stopSignalForwarding func()
+ // stopProfiling stops profiling started at container creation. It
+ // should be called when a sandbox is destroyed.
+ stopProfiling func()
+
// restore is set to true if we are restoring a container.
restore bool
@@ -199,6 +203,21 @@ type Args struct {
TotalMem uint64
// UserLogFD is the file descriptor to write user logs to.
UserLogFD int
+ // ProfileBlockFD is the file descriptor to write a block profile to.
+ // Valid if >=0.
+ ProfileBlockFD int
+ // ProfileCPUFD is the file descriptor to write a CPU profile to.
+ // Valid if >=0.
+ ProfileCPUFD int
+ // ProfileHeapFD is the file descriptor to write a heap profile to.
+ // Valid if >=0.
+ ProfileHeapFD int
+ // ProfileMutexFD is the file descriptor to write a mutex profile to.
+ // Valid if >=0.
+ ProfileMutexFD int
+ // TraceFD is the file descriptor to write a Go execution trace to.
+ // Valid if >=0.
+ TraceFD int
}
// make sure stdioFDs are always the same on initial start and on restore
@@ -207,6 +226,8 @@ const startingStdioFD = 256
// New initializes a new kernel loader configured by spec.
// New also handles setting up a kernel for restoring a container.
func New(args Args) (*Loader, error) {
+ stopProfiling := startProfiling(args)
+
// We initialize the rand package now to make sure /dev/urandom is pre-opened
// on kernels that do not support getrandom(2).
if err := rand.Init(); err != nil {
@@ -220,10 +241,8 @@ func New(args Args) (*Loader, error) {
// Is this a VFSv2 kernel?
if args.Conf.VFS2 {
kernel.VFS2Enabled = true
- if args.Conf.FUSE {
- kernel.FUSEEnabled = true
- }
-
+ kernel.FUSEEnabled = args.Conf.FUSE
+ kernel.LISAFSEnabled = args.Conf.Lisafs
vfs2.Override()
}
@@ -321,6 +340,7 @@ func New(args Args) (*Loader, error) {
if args.TotalMem > 0 {
// Adjust the total memory returned by the Sentry so that applications that
// use /proc/meminfo can make allocations based on this limit.
+ usage.MinimumTotalMemoryBytes = args.TotalMem
usage.MaximumTotalMemoryBytes = args.TotalMem
log.Infof("Setting total memory to %.2f GB", float64(args.TotalMem)/(1<<30))
}
@@ -400,12 +420,13 @@ func New(args Args) (*Loader, error) {
eid := execID{cid: args.ID}
l := &Loader{
- k: k,
- watchdog: dog,
- sandboxID: args.ID,
- processes: map[execID]*execProcess{eid: {}},
- mountHints: mountHints,
- root: info,
+ k: k,
+ watchdog: dog,
+ sandboxID: args.ID,
+ processes: map[execID]*execProcess{eid: {}},
+ mountHints: mountHints,
+ root: info,
+ stopProfiling: stopProfiling,
}
// We don't care about child signals; some platforms can generate a
@@ -498,6 +519,8 @@ func (l *Loader) Destroy() {
for _, f := range l.root.goferFDs {
_ = f.Close()
}
+
+ l.stopProfiling()
}
func createPlatform(conf *config.Config, deviceFile *os.File) (platform.Platform, error) {
@@ -927,7 +950,7 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) {
if kernel.VFS2Enabled {
// task.MountNamespaceVFS2() does not take a ref, so we must do so ourselves.
args.MountNamespaceVFS2 = tg.Leader().MountNamespaceVFS2()
- if !args.MountNamespaceVFS2.TryIncRef() {
+ if args.MountNamespaceVFS2 == nil || !args.MountNamespaceVFS2.TryIncRef() {
return 0, fmt.Errorf("container %q has stopped", args.ContainerID)
}
} else {
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index ac6c26d25..eb56bc563 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -552,7 +552,7 @@ func TestRestoreEnvironment(t *testing.T) {
{
Dev: "9pfs-/",
Flags: fs.MountSourceFlags{ReadOnly: true},
- DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true",
+ DataString: "trans=fd,rfdno=0,wfdno=0",
},
},
"tmpfs": {
@@ -606,11 +606,11 @@ func TestRestoreEnvironment(t *testing.T) {
{
Dev: "9pfs-/",
Flags: fs.MountSourceFlags{ReadOnly: true},
- DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true",
+ DataString: "trans=fd,rfdno=0,wfdno=0",
},
{
Dev: "9pfs-/dev/fd-foo",
- DataString: "trans=fd,rfdno=1,wfdno=1,privateunixsocket=true,cache=remote_revalidating",
+ DataString: "trans=fd,rfdno=1,wfdno=1,cache=remote_revalidating",
},
},
"tmpfs": {
@@ -664,7 +664,7 @@ func TestRestoreEnvironment(t *testing.T) {
{
Dev: "9pfs-/",
Flags: fs.MountSourceFlags{ReadOnly: true},
- DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true",
+ DataString: "trans=fd,rfdno=0,wfdno=0",
},
},
"tmpfs": {
@@ -704,6 +704,7 @@ func TestRestoreEnvironment(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conf := testConfig()
+ conf.VFS2 = true
var ioFDs []*fd.FD
for _, ioFD := range tc.ioFDs {
ioFDs = append(ioFDs, fd.New(ioFD))
@@ -713,7 +714,7 @@ func TestRestoreEnvironment(t *testing.T) {
spec: tc.spec,
goferFDs: ioFDs,
}
- mntr := newContainerMounter(&info, nil, &podMountHints{}, false /* vfs2Enabled */)
+ mntr := newContainerMounter(&info, nil, &podMountHints{}, conf.VFS2)
actualRenv, err := mntr.createRestoreEnvironment(conf)
if !tc.errorExpected && err != nil {
t.Fatalf("could not create restore environment for test:%s", tc.name)
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index 2257ecca7..9fb3ebd95 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -282,12 +282,15 @@ func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkE
for _, addr := range addrs {
proto, tcpipAddr := ipToAddressAndProto(addr.Address)
- ap := tcpip.AddressWithPrefix{
- Address: tcpipAddr,
- PrefixLen: addr.PrefixLen,
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpipAddr,
+ PrefixLen: addr.PrefixLen,
+ },
}
- if err := n.Stack.AddAddressWithPrefix(id, proto, ap); err != nil {
- return fmt.Errorf("AddAddress(%v, %v, %v) failed: %v", id, proto, tcpipAddr, err)
+ if err := n.Stack.AddProtocolAddress(id, protocolAddr, stack.AddressProperties{}); err != nil {
+ return fmt.Errorf("AddProtocolAddress(%d, %+v, {}) failed: %s", id, protocolAddr, err)
}
}
return nil
diff --git a/runsc/boot/profile.go b/runsc/boot/profile.go
new file mode 100644
index 000000000..3ecd3e532
--- /dev/null
+++ b/runsc/boot/profile.go
@@ -0,0 +1,95 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "os"
+ "runtime"
+ "runtime/pprof"
+ "runtime/trace"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/control"
+)
+
+// startProfiling initiates profiling as defined by the ProfileConfig, and
+// returns a function that should be called to stop profiling.
+func startProfiling(args Args) func() {
+ var onStopProfiling []func()
+ stopProfiling := func() {
+ for _, f := range onStopProfiling {
+ f()
+ }
+ }
+
+ if args.ProfileBlockFD >= 0 {
+ file := os.NewFile(uintptr(args.ProfileBlockFD), "profile-block")
+
+ runtime.SetBlockProfileRate(control.DefaultBlockProfileRate)
+ onStopProfiling = append(onStopProfiling, func() {
+ if err := pprof.Lookup("block").WriteTo(file, 0); err != nil {
+ log.Warningf("Error writing block profile: %v", err)
+ }
+ file.Close()
+ runtime.SetBlockProfileRate(0)
+ })
+ }
+
+ if args.ProfileCPUFD >= 0 {
+ file := os.NewFile(uintptr(args.ProfileCPUFD), "profile-cpu")
+
+ pprof.StartCPUProfile(file)
+ onStopProfiling = append(onStopProfiling, func() {
+ pprof.StopCPUProfile()
+ file.Close()
+ })
+ }
+
+ if args.ProfileHeapFD >= 0 {
+ file := os.NewFile(uintptr(args.ProfileHeapFD), "profile-heap")
+
+ onStopProfiling = append(onStopProfiling, func() {
+ if err := pprof.Lookup("heap").WriteTo(file, 0); err != nil {
+ log.Warningf("Error writing heap profile: %v", err)
+ }
+ file.Close()
+ })
+ }
+
+ if args.ProfileMutexFD >= 0 {
+ file := os.NewFile(uintptr(args.ProfileMutexFD), "profile-mutex")
+
+ prev := runtime.SetMutexProfileFraction(control.DefaultMutexProfileRate)
+ onStopProfiling = append(onStopProfiling, func() {
+ if err := pprof.Lookup("mutex").WriteTo(file, 0); err != nil {
+ log.Warningf("Error writing mutex profile: %v", err)
+ }
+ file.Close()
+ runtime.SetMutexProfileFraction(prev)
+ })
+ }
+
+ if args.TraceFD >= 0 {
+ file := os.NewFile(uintptr(args.TraceFD), "trace")
+
+ trace.Start(file)
+ onStopProfiling = append(onStopProfiling, func() {
+ trace.Stop()
+ file.Close()
+ })
+ }
+
+ return stopProfiling
+}
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
index 4286fb978..9f0d1ae36 100644
--- a/runsc/boot/vfs.go
+++ b/runsc/boot/vfs.go
@@ -178,7 +178,7 @@ func (c *containerMounter) mountAll(conf *config.Config, procArgs *kernel.Create
rootProcArgs.Credentials = rootCreds
rootProcArgs.Umask = 0022
rootProcArgs.MaxSymlinkTraversals = linux.MaxSymlinkTraversals
- rootCtx := procArgs.NewContext(c.k)
+ rootCtx := rootProcArgs.NewContext(c.k)
mns, err := c.createMountNamespaceVFS2(rootCtx, conf, rootCreds)
if err != nil {
@@ -213,7 +213,7 @@ func (c *containerMounter) mountAll(conf *config.Config, procArgs *kernel.Create
// createMountNamespaceVFS2 creates the container's root mount and namespace.
func (c *containerMounter) createMountNamespaceVFS2(ctx context.Context, conf *config.Config, creds *auth.Credentials) (*vfs.MountNamespace, error) {
fd := c.fds.remove()
- data := p9MountData(fd, conf.FileAccess, true /* vfs2 */)
+ data := goferMountData(fd, conf.FileAccess, "/", true /* vfs2 */, conf.Lisafs)
// We can't check for overlayfs here because sandbox is chroot'ed and gofer
// can only send mount options for specs.Mounts (specs.Root is missing
@@ -520,7 +520,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
// but unlikely to be correct in this context.
return "", nil, false, fmt.Errorf("9P mount requires a connection FD")
}
- data = p9MountData(m.fd, c.getMountAccessType(conf, m.mount), true /* vfs2 */)
+ data = goferMountData(m.fd, c.getMountAccessType(conf, m.mount), m.mount.Destination, true /* vfs2 */, conf.Lisafs)
internalData = gofer.InternalFilesystemOptions{
UniqueID: m.mount.Destination,
}
diff --git a/runsc/cgroup/BUILD b/runsc/cgroup/BUILD
index f7e892584..d3aec1fff 100644
--- a/runsc/cgroup/BUILD
+++ b/runsc/cgroup/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/cleanup",
"//pkg/log",
+ "//pkg/sync",
"@com_github_cenkalti_backoff//:go_default_library",
"@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go
index 5dbf14376..7a0f0694f 100644
--- a/runsc/cgroup/cgroup.go
+++ b/runsc/cgroup/cgroup.go
@@ -19,7 +19,6 @@ package cgroup
import (
"bufio"
"context"
- "errors"
"fmt"
"io"
"io/ioutil"
@@ -34,6 +33,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
)
const (
@@ -104,17 +104,21 @@ func setOptionalValueUint16(path, name string, val *uint16) error {
func setValue(path, name, data string) error {
fullpath := filepath.Join(path, name)
+ log.Debugf("Setting %q to %q", fullpath, data)
+ return writeFile(fullpath, []byte(data), 0700)
+}
- // Retry writes on EINTR; see:
- // https://github.com/golang/go/issues/38033
- for {
- err := ioutil.WriteFile(fullpath, []byte(data), 0700)
- if err == nil {
- return nil
- } else if !errors.Is(err, unix.EINTR) {
- return err
- }
+// writeFile is similar to ioutil.WriteFile() but doesn't create the file if it
+// doesn't exist.
+func writeFile(path string, data []byte, perm os.FileMode) error {
+ f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, perm)
+ if err != nil {
+ return err
}
+ defer f.Close()
+
+ _, err = f.Write(data)
+ return err
}
func getValue(path, name string) (string, error) {
@@ -155,15 +159,8 @@ func fillFromAncestor(path string) (string, error) {
return "", err
}
- // Retry writes on EINTR; see:
- // https://github.com/golang/go/issues/38033
- for {
- err := ioutil.WriteFile(path, []byte(val), 0700)
- if err == nil {
- break
- } else if !errors.Is(err, unix.EINTR) {
- return "", err
- }
+ if err := writeFile(path, []byte(val), 0700); err != nil {
+ return "", err
}
return val, nil
}
@@ -309,7 +306,13 @@ func NewFromSpec(spec *specs.Spec) (*Cgroup, error) {
if spec.Linux == nil || spec.Linux.CgroupsPath == "" {
return nil, nil
}
- return new("self", spec.Linux.CgroupsPath)
+ return NewFromPath(spec.Linux.CgroupsPath)
+}
+
+// NewFromPath creates a new Cgroup instance from the specified relative path.
+// Cgroup paths are loaded based on the current process.
+func NewFromPath(cgroupsPath string) (*Cgroup, error) {
+ return new("self", cgroupsPath)
}
// NewFromPid loads cgroup for the given process.
@@ -365,21 +368,20 @@ func (c *Cgroup) Install(res *specs.LinuxResources) error {
}
for _, key := range missing {
ctrlr := controllers[key]
- path := c.MakePath(key)
- log.Debugf("Creating cgroup %q: %q", key, path)
- if err := os.MkdirAll(path, 0755); err != nil {
- if ctrlr.optional() && errors.Is(err, unix.EROFS) {
- if err := ctrlr.skip(res); err != nil {
- return err
- }
- log.Infof("Skipping cgroup %q", key)
- continue
+
+ if skip, err := c.createController(key); skip && ctrlr.optional() {
+ if err := ctrlr.skip(res); err != nil {
+ return err
}
+ log.Infof("Skipping cgroup %q, err: %v", key, err)
+ continue
+ } else if err != nil {
return err
}
// Only set controllers that were created by me.
c.Own[key] = true
+ path := c.MakePath(key)
if err := ctrlr.set(res, path); err != nil {
return err
}
@@ -388,10 +390,29 @@ func (c *Cgroup) Install(res *specs.LinuxResources) error {
return nil
}
+// createController creates the controller directory, checking that the
+// controller is enabled in the system. It returns a boolean indicating whether
+// the controller should be skipped (e.g. controller is disabled). In case it
+// should be skipped, it also returns the error it got.
+func (c *Cgroup) createController(name string) (bool, error) {
+ ctrlrPath := filepath.Join(cgroupRoot, name)
+ if _, err := os.Stat(ctrlrPath); err != nil {
+ return os.IsNotExist(err), err
+ }
+
+ path := c.MakePath(name)
+ log.Debugf("Creating cgroup %q: %q", name, path)
+ if err := os.MkdirAll(path, 0755); err != nil {
+ return false, err
+ }
+ return false, nil
+}
+
// Uninstall removes the settings done in Install(). If cgroup path already
// existed when Install() was called, Uninstall is a noop.
func (c *Cgroup) Uninstall() error {
log.Debugf("Deleting cgroup %q", c.Name)
+ wait := sync.WaitGroupErr{}
for key := range controllers {
if !c.Own[key] {
// cgroup is managed by caller, don't touch it.
@@ -400,9 +421,8 @@ func (c *Cgroup) Uninstall() error {
path := c.MakePath(key)
log.Debugf("Removing cgroup controller for key=%q path=%q", key, path)
- // If we try to remove the cgroup too soon after killing the
- // sandbox we might get EBUSY, so we retry for a few seconds
- // until it succeeds.
+ // If we try to remove the cgroup too soon after killing the sandbox we
+ // might get EBUSY, so we retry for a few seconds until it succeeds.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx)
@@ -413,11 +433,18 @@ func (c *Cgroup) Uninstall() error {
}
return err
}
- if err := backoff.Retry(fn, b); err != nil {
- return fmt.Errorf("removing cgroup path %q: %w", path, err)
- }
+ // Run deletions in parallel to remove all directories even if there are
+ // failures/timeouts in other directories.
+ wait.Add(1)
+ go func() {
+ defer wait.Done()
+ if err := backoff.Retry(fn, b); err != nil {
+ wait.ReportError(fmt.Errorf("removing cgroup path %q: %w", path, err))
+ return
+ }
+ }()
}
- return nil
+ return wait.Error()
}
// Join adds the current process to the all controllers. Returns function that
diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go
index 1431b4e8f..0b6a5431b 100644
--- a/runsc/cgroup/cgroup_test.go
+++ b/runsc/cgroup/cgroup_test.go
@@ -129,6 +129,18 @@ func boolPtr(v bool) *bool {
return &v
}
+func createDir(dir string, contents map[string]string) error {
+ for name := range contents {
+ path := filepath.Join(dir, name)
+ f, err := os.Create(path)
+ if err != nil {
+ return err
+ }
+ f.Close()
+ }
+ return nil
+}
+
func checkDir(t *testing.T, dir string, contents map[string]string) {
all, err := ioutil.ReadDir(dir)
if err != nil {
@@ -254,6 +266,9 @@ func TestBlockIO(t *testing.T) {
t.Fatalf("error creating temporary directory: %v", err)
}
defer os.RemoveAll(dir)
+ if err := createDir(dir, tc.wants); err != nil {
+ t.Fatalf("createDir(): %v", err)
+ }
spec := &specs.LinuxResources{
BlockIO: tc.spec,
@@ -304,6 +319,9 @@ func TestCPU(t *testing.T) {
t.Fatalf("error creating temporary directory: %v", err)
}
defer os.RemoveAll(dir)
+ if err := createDir(dir, tc.wants); err != nil {
+ t.Fatalf("createDir(): %v", err)
+ }
spec := &specs.LinuxResources{
CPU: tc.spec,
@@ -343,6 +361,9 @@ func TestCPUSet(t *testing.T) {
t.Fatalf("error creating temporary directory: %v", err)
}
defer os.RemoveAll(dir)
+ if err := createDir(dir, tc.wants); err != nil {
+ t.Fatalf("createDir(): %v", err)
+ }
spec := &specs.LinuxResources{
CPU: tc.spec,
@@ -481,6 +502,9 @@ func TestHugeTlb(t *testing.T) {
t.Fatalf("error creating temporary directory: %v", err)
}
defer os.RemoveAll(dir)
+ if err := createDir(dir, tc.wants); err != nil {
+ t.Fatalf("createDir(): %v", err)
+ }
spec := &specs.LinuxResources{
HugepageLimits: tc.spec,
@@ -542,6 +566,9 @@ func TestMemory(t *testing.T) {
t.Fatalf("error creating temporary directory: %v", err)
}
defer os.RemoveAll(dir)
+ if err := createDir(dir, tc.wants); err != nil {
+ t.Fatalf("createDir(): %v", err)
+ }
spec := &specs.LinuxResources{
Memory: tc.spec,
@@ -584,6 +611,9 @@ func TestNetworkClass(t *testing.T) {
t.Fatalf("error creating temporary directory: %v", err)
}
defer os.RemoveAll(dir)
+ if err := createDir(dir, tc.wants); err != nil {
+ t.Fatalf("createDir(): %v", err)
+ }
spec := &specs.LinuxResources{
Network: tc.spec,
@@ -631,6 +661,9 @@ func TestNetworkPriority(t *testing.T) {
t.Fatalf("error creating temporary directory: %v", err)
}
defer os.RemoveAll(dir)
+ if err := createDir(dir, tc.wants); err != nil {
+ t.Fatalf("createDir(): %v", err)
+ }
spec := &specs.LinuxResources{
Network: tc.spec,
@@ -671,6 +704,9 @@ func TestPids(t *testing.T) {
t.Fatalf("error creating temporary directory: %v", err)
}
defer os.RemoveAll(dir)
+ if err := createDir(dir, tc.wants); err != nil {
+ t.Fatalf("createDir(): %v", err)
+ }
spec := &specs.LinuxResources{
Pids: tc.spec,
diff --git a/runsc/cli/main.go b/runsc/cli/main.go
index 3556d7665..359286da7 100644
--- a/runsc/cli/main.go
+++ b/runsc/cli/main.go
@@ -228,7 +228,8 @@ func Main(version string) {
log.Infof("\t\tFileAccess: %v, overlay: %t", conf.FileAccess, conf.Overlay)
log.Infof("\t\tNetwork: %v, logging: %t", conf.Network, conf.LogPackets)
log.Infof("\t\tStrace: %t, max size: %d, syscalls: %s", conf.Strace, conf.StraceLogSize, conf.StraceSyscalls)
- log.Infof("\t\tVFS2 enabled: %v", conf.VFS2)
+ log.Infof("\t\tVFS2 enabled: %t, LISAFS: %t", conf.VFS2, conf.Lisafs)
+	log.Infof("\t\tDebug: %t", conf.Debug)
log.Infof("***************************")
if conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go
index f5c9821b2..e33a7f3cb 100644
--- a/runsc/cmd/boot.go
+++ b/runsc/cmd/boot.go
@@ -79,6 +79,26 @@ type Boot struct {
// sandbox (e.g. gofer) and sent through this FD.
mountsFD int
+ // profileBlockFD is the file descriptor to write a block profile to.
+ // Valid if >= 0.
+ profileBlockFD int
+
+ // profileCPUFD is the file descriptor to write a CPU profile to.
+ // Valid if >= 0.
+ profileCPUFD int
+
+ // profileHeapFD is the file descriptor to write a heap profile to.
+ // Valid if >= 0.
+ profileHeapFD int
+
+ // profileMutexFD is the file descriptor to write a mutex profile to.
+ // Valid if >= 0.
+ profileMutexFD int
+
+ // traceFD is the file descriptor to write a Go execution trace to.
+ // Valid if >= 0.
+ traceFD int
+
// pidns is set if the sandbox is in its own pid namespace.
pidns bool
@@ -119,6 +139,11 @@ func (b *Boot) SetFlags(f *flag.FlagSet) {
f.IntVar(&b.userLogFD, "user-log-fd", 0, "file descriptor to write user logs to. 0 means no logging.")
f.IntVar(&b.startSyncFD, "start-sync-fd", -1, "required FD to used to synchronize sandbox startup")
f.IntVar(&b.mountsFD, "mounts-fd", -1, "mountsFD is the file descriptor to read list of mounts after they have been resolved (direct paths, no symlinks).")
+ f.IntVar(&b.profileBlockFD, "profile-block-fd", -1, "file descriptor to write block profile to. -1 disables profiling.")
+ f.IntVar(&b.profileCPUFD, "profile-cpu-fd", -1, "file descriptor to write CPU profile to. -1 disables profiling.")
+ f.IntVar(&b.profileHeapFD, "profile-heap-fd", -1, "file descriptor to write heap profile to. -1 disables profiling.")
+ f.IntVar(&b.profileMutexFD, "profile-mutex-fd", -1, "file descriptor to write mutex profile to. -1 disables profiling.")
+ f.IntVar(&b.traceFD, "trace-fd", -1, "file descriptor to write Go execution trace to. -1 disables tracing.")
f.BoolVar(&b.attached, "attached", false, "if attached is true, kills the sandbox process when the parent process terminates")
}
@@ -213,16 +238,21 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// Create the loader.
bootArgs := boot.Args{
- ID: f.Arg(0),
- Spec: spec,
- Conf: conf,
- ControllerFD: b.controllerFD,
- Device: os.NewFile(uintptr(b.deviceFD), "platform device"),
- GoferFDs: b.ioFDs.GetArray(),
- StdioFDs: b.stdioFDs.GetArray(),
- NumCPU: b.cpuNum,
- TotalMem: b.totalMem,
- UserLogFD: b.userLogFD,
+ ID: f.Arg(0),
+ Spec: spec,
+ Conf: conf,
+ ControllerFD: b.controllerFD,
+ Device: os.NewFile(uintptr(b.deviceFD), "platform device"),
+ GoferFDs: b.ioFDs.GetArray(),
+ StdioFDs: b.stdioFDs.GetArray(),
+ NumCPU: b.cpuNum,
+ TotalMem: b.totalMem,
+ UserLogFD: b.userLogFD,
+ ProfileBlockFD: b.profileBlockFD,
+ ProfileCPUFD: b.profileCPUFD,
+ ProfileHeapFD: b.profileHeapFD,
+ ProfileMutexFD: b.profileMutexFD,
+ TraceFD: b.traceFD,
}
l, err := boot.New(bootArgs)
if err != nil {
diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go
index f773ccca0..318753728 100644
--- a/runsc/cmd/debug.go
+++ b/runsc/cmd/debug.go
@@ -37,9 +37,9 @@ type Debug struct {
pid int
stacks bool
signal int
- profileHeap string
- profileCPU string
profileBlock string
+ profileCPU string
+ profileHeap string
profileMutex string
trace string
strace string
@@ -70,9 +70,9 @@ func (*Debug) Usage() string {
func (d *Debug) SetFlags(f *flag.FlagSet) {
f.IntVar(&d.pid, "pid", 0, "sandbox process ID. Container ID is not necessary if this is set")
f.BoolVar(&d.stacks, "stacks", false, "if true, dumps all sandbox stacks to the log")
- f.StringVar(&d.profileHeap, "profile-heap", "", "writes heap profile to the given file.")
- f.StringVar(&d.profileCPU, "profile-cpu", "", "writes CPU profile to the given file.")
f.StringVar(&d.profileBlock, "profile-block", "", "writes block profile to the given file.")
+ f.StringVar(&d.profileCPU, "profile-cpu", "", "writes CPU profile to the given file.")
+ f.StringVar(&d.profileHeap, "profile-heap", "", "writes heap profile to the given file.")
f.StringVar(&d.profileMutex, "profile-mutex", "", "writes mutex profile to the given file.")
f.DurationVar(&d.delay, "delay", time.Hour, "amount of time to delay for collecting heap and goroutine profiles.")
f.DurationVar(&d.duration, "duration", time.Hour, "amount of time to wait for CPU and trace profiles.")
@@ -90,6 +90,13 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
var c *container.Container
conf := args[0].(*config.Config)
+ if conf.ProfileBlock != "" || conf.ProfileCPU != "" || conf.ProfileHeap != "" || conf.ProfileMutex != "" {
+ return Errorf("global -profile-{block,cpu,heap,mutex} flags have no effect on runsc debug. Pass runsc debug -profile-{block,cpu,heap,mutex} instead")
+ }
+ if conf.TraceFile != "" {
+ return Errorf("global -trace flag has no effect on runsc debug. Pass runsc debug -trace instead")
+ }
+
if d.pid == 0 {
// No pid, container ID must have been provided.
if f.NArg() != 1 {
@@ -219,19 +226,19 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// Open profiling files.
var (
- heapFile *os.File
- cpuFile *os.File
- traceFile *os.File
blockFile *os.File
+ cpuFile *os.File
+ heapFile *os.File
mutexFile *os.File
+ traceFile *os.File
)
- if d.profileHeap != "" {
- f, err := os.OpenFile(d.profileHeap, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if d.profileBlock != "" {
+ f, err := os.OpenFile(d.profileBlock, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
- return Errorf("error opening heap profile output: %v", err)
+ return Errorf("error opening blocking profile output: %v", err)
}
defer f.Close()
- heapFile = f
+ blockFile = f
}
if d.profileCPU != "" {
f, err := os.OpenFile(d.profileCPU, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
@@ -241,20 +248,13 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
defer f.Close()
cpuFile = f
}
- if d.trace != "" {
- f, err := os.OpenFile(d.trace, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
- if err != nil {
- return Errorf("error opening trace profile output: %v", err)
- }
- traceFile = f
- }
- if d.profileBlock != "" {
- f, err := os.OpenFile(d.profileBlock, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if d.profileHeap != "" {
+ f, err := os.OpenFile(d.profileHeap, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
- return Errorf("error opening blocking profile output: %v", err)
+ return Errorf("error opening heap profile output: %v", err)
}
defer f.Close()
- blockFile = f
+ heapFile = f
}
if d.profileMutex != "" {
f, err := os.OpenFile(d.profileMutex, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
@@ -264,21 +264,28 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
defer f.Close()
mutexFile = f
}
+ if d.trace != "" {
+ f, err := os.OpenFile(d.trace, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return Errorf("error opening trace profile output: %v", err)
+ }
+ traceFile = f
+ }
// Collect profiles.
var (
wg sync.WaitGroup
- heapErr error
- cpuErr error
- traceErr error
blockErr error
+ cpuErr error
+ heapErr error
mutexErr error
+ traceErr error
)
- if heapFile != nil {
+ if blockFile != nil {
wg.Add(1)
go func() {
defer wg.Done()
- heapErr = c.Sandbox.HeapProfile(heapFile, d.delay)
+ blockErr = c.Sandbox.BlockProfile(blockFile, d.duration)
}()
}
if cpuFile != nil {
@@ -288,25 +295,25 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
cpuErr = c.Sandbox.CPUProfile(cpuFile, d.duration)
}()
}
- if traceFile != nil {
+ if heapFile != nil {
wg.Add(1)
go func() {
defer wg.Done()
- traceErr = c.Sandbox.Trace(traceFile, d.duration)
+ heapErr = c.Sandbox.HeapProfile(heapFile, d.delay)
}()
}
- if blockFile != nil {
+ if mutexFile != nil {
wg.Add(1)
go func() {
defer wg.Done()
- blockErr = c.Sandbox.BlockProfile(blockFile, d.duration)
+ mutexErr = c.Sandbox.MutexProfile(mutexFile, d.duration)
}()
}
- if mutexFile != nil {
+ if traceFile != nil {
wg.Add(1)
go func() {
defer wg.Done()
- mutexErr = c.Sandbox.MutexProfile(mutexFile, d.duration)
+ traceErr = c.Sandbox.Trace(traceFile, d.duration)
}()
}
@@ -339,31 +346,31 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// Collect all errors.
errorCount := 0
- if heapErr != nil {
+ if blockErr != nil {
errorCount++
- log.Infof("error collecting heap profile: %v", heapErr)
- os.Remove(heapFile.Name())
+ log.Infof("error collecting block profile: %v", blockErr)
+ os.Remove(blockFile.Name())
}
if cpuErr != nil {
errorCount++
log.Infof("error collecting cpu profile: %v", cpuErr)
os.Remove(cpuFile.Name())
}
- if traceErr != nil {
- errorCount++
- log.Infof("error collecting trace profile: %v", traceErr)
- os.Remove(traceFile.Name())
- }
- if blockErr != nil {
+ if heapErr != nil {
errorCount++
- log.Infof("error collecting block profile: %v", blockErr)
- os.Remove(blockFile.Name())
+ log.Infof("error collecting heap profile: %v", heapErr)
+ os.Remove(heapFile.Name())
}
if mutexErr != nil {
errorCount++
log.Infof("error collecting mutex profile: %v", mutexErr)
os.Remove(mutexFile.Name())
}
+ if traceErr != nil {
+ errorCount++
+ log.Infof("error collecting trace profile: %v", traceErr)
+ os.Remove(traceFile.Name())
+ }
if errorCount > 0 {
return subcommands.ExitFailure
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 2193e9040..c65e0267a 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -71,8 +71,8 @@ func (*Gofer) Name() string {
}
// Synopsis implements subcommands.Command.
-func (*Gofer) Synopsis() string {
- return "launch a gofer process that serves files over 9P protocol (internal use only)"
+func (*Gofer) Synopsis() string {
+	return "launch a gofer process that serves files over the protocol (9P or lisafs) defined in the config (internal use only)"
}
// Usage implements subcommands.Command.
@@ -83,7 +83,7 @@ func (*Gofer) Usage() string {
// SetFlags implements subcommands.Command.
func (g *Gofer) SetFlags(f *flag.FlagSet) {
f.StringVar(&g.bundleDir, "bundle", "", "path to the root of the bundle directory, defaults to the current directory")
- f.Var(&g.ioFDs, "io-fds", "list of FDs to connect 9P servers. They must follow this order: root first, then mounts as defined in the spec")
+ f.Var(&g.ioFDs, "io-fds", "list of FDs to connect gofer servers. They must follow this order: root first, then mounts as defined in the spec")
f.BoolVar(&g.applyCaps, "apply-caps", true, "if true, apply capabilities to restrict what the Gofer process can do")
f.BoolVar(&g.setUpRoot, "setup-root", true, "if true, set up an empty root for the process")
f.IntVar(&g.specFD, "spec-fd", -1, "required fd with the container spec")
@@ -160,10 +160,98 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
}
log.Infof("Process chroot'd to %q", root)
+ // Initialize filters.
+ if conf.FSGoferHostUDS {
+ filter.InstallUDSFilters()
+ }
+
+ if conf.Verity {
+ filter.InstallXattrFilters()
+ }
+
+ if err := filter.Install(); err != nil {
+ Fatalf("installing seccomp filters: %v", err)
+ }
+
+ if conf.Lisafs {
+ return g.serveLisafs(spec, conf, root)
+ }
+ return g.serve9P(spec, conf, root)
+}
+
+func newSocket(ioFD int) *unet.Socket {
+ socket, err := unet.NewSocket(ioFD)
+ if err != nil {
+ Fatalf("creating server on FD %d: %v", ioFD, err)
+ }
+ return socket
+}
+
+func (g *Gofer) serveLisafs(spec *specs.Spec, conf *config.Config, root string) subcommands.ExitStatus {
+ type connectionConfig struct {
+ sock *unet.Socket
+ readonly bool
+ }
+ cfgs := make([]connectionConfig, 0, len(spec.Mounts)+1)
+ server := fsgofer.NewLisafsServer(fsgofer.Config{
+ // These are global options. Ignore readonly configuration, that is set on
+ // a per connection basis.
+ HostUDS: conf.FSGoferHostUDS,
+ EnableVerityXattr: conf.Verity,
+ })
+
+ // Start with root mount, then add any other additional mount as needed.
+ cfgs = append(cfgs, connectionConfig{
+ sock: newSocket(g.ioFDs[0]),
+ readonly: spec.Root.Readonly || conf.Overlay,
+ })
+ log.Infof("Serving %q mapped to %q on FD %d (ro: %t)", "/", root, g.ioFDs[0], cfgs[0].readonly)
+
+ mountIdx := 1 // first one is the root
+ for _, m := range spec.Mounts {
+ if !specutils.IsGoferMount(m, conf.VFS2) {
+ continue
+ }
+
+ if !filepath.IsAbs(m.Destination) {
+ Fatalf("mount destination must be absolute: %q", m.Destination)
+ }
+ if mountIdx >= len(g.ioFDs) {
+ Fatalf("no FD found for mount. Did you forget --io-fd? FDs: %d, Mount: %+v", len(g.ioFDs), m)
+ }
+
+ cfgs = append(cfgs, connectionConfig{
+ sock: newSocket(g.ioFDs[mountIdx]),
+ readonly: isReadonlyMount(m.Options) || conf.Overlay,
+ })
+
+ log.Infof("Serving %q mapped on FD %d (ro: %t)", m.Destination, g.ioFDs[mountIdx], cfgs[mountIdx].readonly)
+ mountIdx++
+ }
+
+ if mountIdx != len(g.ioFDs) {
+ Fatalf("too many FDs passed for mounts. mounts: %d, FDs: %d", mountIdx, len(g.ioFDs))
+ }
+ cfgs = cfgs[:mountIdx]
+
+ for _, cfg := range cfgs {
+ conn, err := server.CreateConnection(cfg.sock, cfg.readonly)
+ if err != nil {
+ Fatalf("starting connection on FD %d for gofer mount failed: %v", cfg.sock.FD(), err)
+ }
+ server.StartConnection(conn)
+ }
+ server.Wait()
+ log.Infof("All lisafs servers exited.")
+ return subcommands.ExitSuccess
+}
+
+func (g *Gofer) serve9P(spec *specs.Spec, conf *config.Config, root string) subcommands.ExitStatus {
// Start with root mount, then add any other additional mount as needed.
ats := make([]p9.Attacher, 0, len(spec.Mounts)+1)
ap, err := fsgofer.NewAttachPoint("/", fsgofer.Config{
ROMount: spec.Root.Readonly || conf.Overlay,
+ HostUDS: conf.FSGoferHostUDS,
EnableVerityXattr: conf.Verity,
})
if err != nil {
@@ -174,7 +262,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
mountIdx := 1 // first one is the root
for _, m := range spec.Mounts {
- if specutils.Is9PMount(m, conf.VFS2) {
+ if specutils.IsGoferMount(m, conf.VFS2) {
cfg := fsgofer.Config{
ROMount: isReadonlyMount(m.Options) || conf.Overlay,
HostUDS: conf.FSGoferHostUDS,
@@ -197,26 +285,9 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("too many FDs passed for mounts. mounts: %d, FDs: %d", mountIdx, len(g.ioFDs))
}
- if conf.FSGoferHostUDS {
- filter.InstallUDSFilters()
- }
-
- if conf.Verity {
- filter.InstallXattrFilters()
- }
-
- if err := filter.Install(); err != nil {
- Fatalf("installing seccomp filters: %v", err)
- }
-
- runServers(ats, g.ioFDs)
- return subcommands.ExitSuccess
-}
-
-func runServers(ats []p9.Attacher, ioFDs []int) {
// Run the loops and wait for all to exit.
var wg sync.WaitGroup
- for i, ioFD := range ioFDs {
+ for i, ioFD := range g.ioFDs {
wg.Add(1)
go func(ioFD int, at p9.Attacher) {
socket, err := unet.NewSocket(ioFD)
@@ -232,6 +303,7 @@ func runServers(ats []p9.Attacher, ioFDs []int) {
}
wg.Wait()
log.Infof("All 9P servers exited.")
+ return subcommands.ExitSuccess
}
func (g *Gofer) writeMounts(mounts []specs.Mount) error {
@@ -362,7 +434,7 @@ func setupRootFS(spec *specs.Spec, conf *config.Config) error {
// creates directories as needed.
func setupMounts(conf *config.Config, mounts []specs.Mount, root, procPath string) error {
for _, m := range mounts {
- if !specutils.Is9PMount(m, conf.VFS2) {
+ if !specutils.IsGoferMount(m, conf.VFS2) {
continue
}
@@ -402,7 +474,7 @@ func setupMounts(conf *config.Config, mounts []specs.Mount, root, procPath strin
func resolveMounts(conf *config.Config, mounts []specs.Mount, root string) ([]specs.Mount, error) {
cleanMounts := make([]specs.Mount, 0, len(mounts))
for _, m := range mounts {
- if !specutils.Is9PMount(m, conf.VFS2) {
+ if !specutils.IsGoferMount(m, conf.VFS2) {
cleanMounts = append(cleanMounts, m)
continue
}
diff --git a/runsc/cmd/run.go b/runsc/cmd/run.go
index 722181aff..da11c9d06 100644
--- a/runsc/cmd/run.go
+++ b/runsc/cmd/run.go
@@ -68,7 +68,14 @@ func (r *Run) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s
waitStatus := args[1].(*unix.WaitStatus)
if conf.Rootless {
- return Errorf("Rootless mode not supported with %q", r.Name())
+ if conf.Network == config.NetworkSandbox {
+ return Errorf("sandbox network isn't supported with --rootless, use --network=none or --network=host")
+ }
+
+ if err := specutils.MaybeRunAsRoot(); err != nil {
+ return Errorf("Error executing inside namespace: %v", err)
+ }
+ // Execution will continue here if no more capabilities are needed...
}
bundleDir := r.bundleDir
diff --git a/runsc/cmd/spec.go b/runsc/cmd/spec.go
index 55194e641..3a7f2a2c4 100644
--- a/runsc/cmd/spec.go
+++ b/runsc/cmd/spec.go
@@ -30,7 +30,6 @@ func writeSpec(w io.Writer, cwd string, netns string, args []string) error {
spec := &specs.Spec{
Version: "1.0.0",
Process: &specs.Process{
- Terminal: true,
User: specs.User{
UID: 0,
GID: 0,
diff --git a/runsc/config/config.go b/runsc/config/config.go
index a230baa29..91142888f 100644
--- a/runsc/config/config.go
+++ b/runsc/config/config.go
@@ -140,6 +140,26 @@ type Config struct {
// ProfileEnable is set to prepare the sandbox to be profiled.
ProfileEnable bool `flag:"profile"`
+ // ProfileBlock collects a block profile to the passed file for the
+	// duration of the container execution. Requires ProfileEnable.
+ ProfileBlock string `flag:"profile-block"`
+
+ // ProfileCPU collects a CPU profile to the passed file for the
+	// duration of the container execution. Requires ProfileEnable.
+ ProfileCPU string `flag:"profile-cpu"`
+
+ // ProfileHeap collects a heap profile to the passed file for the
+	// duration of the container execution. Requires ProfileEnable.
+ ProfileHeap string `flag:"profile-heap"`
+
+ // ProfileMutex collects a mutex profile to the passed file for the
+	// duration of the container execution. Requires ProfileEnable.
+ ProfileMutex string `flag:"profile-mutex"`
+
+ // TraceFile collects a Go runtime execution trace to the passed file
+ // for the duration of the container execution.
+ TraceFile string `flag:"trace"`
+
// Controls defines the controls that may be enabled.
Controls controlConfig `flag:"controls"`
@@ -173,6 +193,9 @@ type Config struct {
// Enables VFS2.
VFS2 bool `flag:"vfs2"`
+ // Enable lisafs.
+ Lisafs bool `flag:"lisafs"`
+
// Enables FUSE usage.
FUSE bool `flag:"fuse"`
@@ -207,6 +230,21 @@ func (c *Config) validate() error {
if c.NumNetworkChannels <= 0 {
return fmt.Errorf("num_network_channels must be > 0, got: %d", c.NumNetworkChannels)
}
+ // Require profile flags to explicitly opt-in to profiling with
+ // -profile rather than implying it since these options have security
+ // implications.
+ if c.ProfileBlock != "" && !c.ProfileEnable {
+ return fmt.Errorf("profile-block flag requires enabling profiling with profile flag")
+ }
+ if c.ProfileCPU != "" && !c.ProfileEnable {
+ return fmt.Errorf("profile-cpu flag requires enabling profiling with profile flag")
+ }
+ if c.ProfileHeap != "" && !c.ProfileEnable {
+ return fmt.Errorf("profile-heap flag requires enabling profiling with profile flag")
+ }
+ if c.ProfileMutex != "" && !c.ProfileEnable {
+ return fmt.Errorf("profile-mutex flag requires enabling profiling with profile flag")
+ }
return nil
}
diff --git a/runsc/config/flags.go b/runsc/config/flags.go
index cc5aba474..663b20d9c 100644
--- a/runsc/config/flags.go
+++ b/runsc/config/flags.go
@@ -63,6 +63,11 @@ func RegisterFlags() {
flag.Var(watchdogActionPtr(watchdog.LogWarning), "watchdog-action", "sets what action the watchdog takes when triggered: log (default), panic.")
flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.")
flag.Bool("profile", false, "prepares the sandbox to use Golang profiler. Note that enabling profiler loosens the seccomp protection added to the sandbox (DO NOT USE IN PRODUCTION).")
+ flag.String("profile-block", "", "collects a block profile to this file path for the duration of the container execution. Requires -profile=true.")
+ flag.String("profile-cpu", "", "collects a CPU profile to this file path for the duration of the container execution. Requires -profile=true.")
+ flag.String("profile-heap", "", "collects a heap profile to this file path for the duration of the container execution. Requires -profile=true.")
+ flag.String("profile-mutex", "", "collects a mutex profile to this file path for the duration of the container execution. Requires -profile=true.")
+ flag.String("trace", "", "collects a Go runtime execution trace to this file path for the duration of the container execution.")
flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.")
flag.Var(leakModePtr(refs.NoLeakChecking), "ref-leak-mode", "sets reference leak check mode: disabled (default), log-names, log-traces.")
flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)")
@@ -75,8 +80,9 @@ func RegisterFlags() {
flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.")
flag.Bool("verity", false, "specifies whether a verity file system will be mounted.")
flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.")
- flag.Bool("vfs2", false, "enables VFSv2. This uses the new VFS layer that is faster than the previous one.")
+ flag.Bool("vfs2", true, "enables VFSv2. This uses the new VFS layer that is faster than the previous one.")
flag.Bool("fuse", false, "TEST ONLY; use while FUSE in VFSv2 is landing. This allows the use of the new experimental FUSE filesystem.")
+ flag.Bool("lisafs", false, "Enables lisafs protocol instead of 9P. This is only effective with VFS2.")
flag.Bool("cgroupfs", false, "Automatically mount cgroupfs.")
// Flags that control sandbox runtime behavior: network related.
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index 5314549d6..4e744e604 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -19,7 +19,7 @@ go_library(
"//pkg/cleanup",
"//pkg/log",
"//pkg/sentry/control",
- "//pkg/sentry/sighandling",
+ "//pkg/sighandling",
"//pkg/sync",
"//runsc/boot",
"//runsc/cgroup",
diff --git a/runsc/container/container.go b/runsc/container/container.go
index 50b0dd5e7..77a0f7eba 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -35,7 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
- "gvisor.dev/gvisor/pkg/sentry/sighandling"
+ "gvisor.dev/gvisor/pkg/sighandling"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/cgroup"
"gvisor.dev/gvisor/runsc/config"
@@ -44,6 +44,8 @@ import (
"gvisor.dev/gvisor/runsc/specutils"
)
+const cgroupParentAnnotation = "dev.gvisor.spec.cgroup-parent"
+
// validateID validates the container id.
func validateID(id string) error {
// See libcontainer/factory_linux.go.
@@ -113,6 +115,16 @@ type Container struct {
// container is created and reset when the sandbox is destroyed.
Sandbox *sandbox.Sandbox `json:"sandbox"`
+ // CompatCgroup has the cgroup configuration for the container. For the single
+	// container case, container cgroup is set in `c.Sandbox` only. CompatCgroup
+ // is only set for multi-container, where the `c.Sandbox` cgroup represents
+ // the entire pod.
+ //
+ // Note that CompatCgroup is created only for compatibility with tools
+ // that expect container cgroups to exist. Setting limits here makes no change
+ // to the container in question.
+ CompatCgroup *cgroup.Cgroup `json:"compatCgroup"`
+
// Saver handles load from/save to the state file safely from multiple
// processes.
Saver StateFile `json:"saver"`
@@ -233,27 +245,12 @@ func New(conf *config.Config, args Args) (*Container, error) {
}
// Create and join cgroup before processes are created to ensure they are
// part of the cgroup from the start (and all their children processes).
- cg, err := cgroup.NewFromSpec(args.Spec)
+ parentCgroup, subCgroup, err := c.setupCgroupForRoot(conf, args.Spec)
if err != nil {
return nil, err
}
- if cg != nil {
- // TODO(gvisor.dev/issue/3481): Remove when cgroups v2 is supported.
- if !conf.Rootless && cgroup.IsOnlyV2() {
- return nil, fmt.Errorf("cgroups V2 is not yet supported. Enable cgroups V1 and retry")
- }
- // If there is cgroup config, install it before creating sandbox process.
- if err := cg.Install(args.Spec.Linux.Resources); err != nil {
- switch {
- case errors.Is(err, unix.EACCES) && conf.Rootless:
- log.Warningf("Skipping cgroup configuration in rootless mode: %v", err)
- cg = nil
- default:
- return nil, fmt.Errorf("configuring cgroup: %v", err)
- }
- }
- }
- if err := runInCgroup(cg, func() error {
+ c.CompatCgroup = subCgroup
+ if err := runInCgroup(parentCgroup, func() error {
ioFiles, specFile, err := c.createGoferProcess(args.Spec, conf, args.BundleDir, args.Attached)
if err != nil {
return err
@@ -269,7 +266,7 @@ func New(conf *config.Config, args Args) (*Container, error) {
UserLog: args.UserLog,
IOFiles: ioFiles,
MountsFile: specFile,
- Cgroup: cg,
+ Cgroup: parentCgroup,
Attached: args.Attached,
}
sand, err := sandbox.New(conf, sandArgs)
@@ -296,6 +293,12 @@ func New(conf *config.Config, args Args) (*Container, error) {
}
c.Sandbox = sb.Sandbox
+ subCgroup, err := c.setupCgroupForSubcontainer(conf, args.Spec)
+ if err != nil {
+ return nil, err
+ }
+ c.CompatCgroup = subCgroup
+
// If the console control socket file is provided, then create a new
// pty master/slave pair and send the TTY to the sandbox process.
var tty *os.File
@@ -781,16 +784,16 @@ func (c *Container) saveLocked() error {
// root containers), and waits for the container or sandbox and the gofer
// to stop. If any of them doesn't stop before timeout, an error is returned.
func (c *Container) stop() error {
- var cgroup *cgroup.Cgroup
+ var parentCgroup *cgroup.Cgroup
if c.Sandbox != nil {
log.Debugf("Destroying container, cid: %s", c.ID)
if err := c.Sandbox.DestroyContainer(c.ID); err != nil {
return fmt.Errorf("destroying container %q: %v", c.ID, err)
}
- // Only uninstall cgroup for sandbox stop.
+ // Only uninstall parentCgroup for sandbox stop.
if c.Sandbox.IsRootContainer(c.ID) {
- cgroup = c.Sandbox.Cgroup
+ parentCgroup = c.Sandbox.Cgroup
}
// Only set sandbox to nil after it has been told to destroy the container.
c.Sandbox = nil
@@ -809,9 +812,16 @@ func (c *Container) stop() error {
return err
}
- // Gofer is running in cgroups, so Cgroup.Uninstall has to be called after it.
- if cgroup != nil {
- if err := cgroup.Uninstall(); err != nil {
+ // Delete container cgroup if any.
+ if c.CompatCgroup != nil {
+ if err := c.CompatCgroup.Uninstall(); err != nil {
+ return err
+ }
+ }
+ // Gofer is running inside parentCgroup, so Cgroup.Uninstall has to be called
+ // after the gofer has stopped.
+ if parentCgroup != nil {
+ if err := parentCgroup.Uninstall(); err != nil {
return err
}
}
@@ -917,7 +927,7 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *config.Config, bu
// Add root mount and then add any other additional mounts.
mountCount := 1
for _, m := range spec.Mounts {
- if specutils.Is9PMount(m, conf.VFS2) {
+ if specutils.IsGoferMount(m, conf.VFS2) {
mountCount++
}
}
@@ -1208,3 +1218,80 @@ func (c *Container) populateStats(event *boot.EventOut) {
event.Event.Data.CPU.Usage.Total = uint64(total)
return
}
+
+// setupCgroupForRoot configures and returns cgroup for the sandbox and the
+// root container. If `cgroupParentAnnotation` is set, use that path as the
+// sandbox cgroup and use Spec.Linux.CgroupsPath as the root container cgroup.
+func (c *Container) setupCgroupForRoot(conf *config.Config, spec *specs.Spec) (*cgroup.Cgroup, *cgroup.Cgroup, error) {
+ var parentCgroup *cgroup.Cgroup
+ if parentPath, ok := spec.Annotations[cgroupParentAnnotation]; ok {
+ var err error
+ parentCgroup, err = cgroup.NewFromPath(parentPath)
+ if err != nil {
+ return nil, nil, err
+ }
+ } else {
+ var err error
+ parentCgroup, err = cgroup.NewFromSpec(spec)
+ if parentCgroup == nil || err != nil {
+ return nil, nil, err
+ }
+ }
+
+ var err error
+ parentCgroup, err = cgroupInstall(conf, parentCgroup, spec.Linux.Resources)
+ if parentCgroup == nil || err != nil {
+ return nil, nil, err
+ }
+
+ subCgroup, err := c.setupCgroupForSubcontainer(conf, spec)
+ if err != nil {
+ _ = parentCgroup.Uninstall()
+ return nil, nil, err
+ }
+ return parentCgroup, subCgroup, nil
+}
+
+// setupCgroupForSubcontainer sets up empty cgroups for subcontainers. Since
+// subcontainers run exclusively inside the sandbox, subcontainer cgroups on the
+// host have no effect on them. However, some tools (e.g. cAdvisor) use cgroup
+// paths to discover new containers and report stats for them.
+func (c *Container) setupCgroupForSubcontainer(conf *config.Config, spec *specs.Spec) (*cgroup.Cgroup, error) {
+ if isRoot(spec) {
+ if _, ok := spec.Annotations[cgroupParentAnnotation]; !ok {
+ return nil, nil
+ }
+ }
+
+ cg, err := cgroup.NewFromSpec(spec)
+ if cg == nil || err != nil {
+ return nil, err
+ }
+ // Use empty resources, just want the directory structure created.
+ return cgroupInstall(conf, cg, &specs.LinuxResources{})
+}
+
+// cgroupInstall creates cgroups dir structure and sets their respective
+// resources. In case of success, returns the cgroups instance and nil error.
+// For rootless, it's possible that cgroups operations fail, in this case the
+// error is suppressed and a nil cgroups instance is returned to indicate that
+// no cgroups was configured.
+func cgroupInstall(conf *config.Config, cg *cgroup.Cgroup, res *specs.LinuxResources) (*cgroup.Cgroup, error) {
+ // TODO(gvisor.dev/issue/3481): Remove when cgroups v2 is supported.
+ if cgroup.IsOnlyV2() {
+ if conf.Rootless {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("cgroups V2 is not yet supported. Enable cgroups V1 and retry")
+ }
+ if err := cg.Install(res); err != nil {
+ switch {
+ case errors.Is(err, unix.EACCES) && conf.Rootless:
+ log.Warningf("Skipping cgroup configuration in rootless mode: %v", err)
+ return nil, nil
+ default:
+ return nil, fmt.Errorf("configuring cgroup: %v", err)
+ }
+ }
+ return cg, nil
+}
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index 681f5c1a9..94e58d921 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -407,9 +407,9 @@ var (
func configsHelper(t *testing.T, opts ...configOption) map[string]*config.Config {
// Always load the default config.
cs := make(map[string]*config.Config)
- testutil.TestConfig(t)
for _, o := range opts {
c := testutil.TestConfig(t)
+ c.VFS2 = false
switch o {
case overlay:
c.Overlay = true
@@ -2615,6 +2615,8 @@ func TestCat(t *testing.T) {
f.Close()
spec, conf := sleepSpecConf(t)
+ // TODO(gvisor.dev/issue/6742): Add VFS2 support.
+ conf.VFS2 = false
_, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
@@ -2829,3 +2831,46 @@ func TestStream(t *testing.T) {
t.Errorf("out got %s, want include %s", buf, want)
}
}
+
+// TestProfile checks that profiling options generate profiles.
+func TestProfile(t *testing.T) {
+ // Perform a non-trivial amount of work so we actually capture
+ // something in the profiles.
+ spec := testutil.NewSpecWithArgs("/bin/bash", "-c", "true")
+ conf := testutil.TestConfig(t)
+ conf.ProfileEnable = true
+ conf.ProfileBlock = filepath.Join(t.TempDir(), "block.pprof")
+ conf.ProfileCPU = filepath.Join(t.TempDir(), "cpu.pprof")
+ conf.ProfileHeap = filepath.Join(t.TempDir(), "heap.pprof")
+ conf.ProfileMutex = filepath.Join(t.TempDir(), "mutex.pprof")
+ conf.TraceFile = filepath.Join(t.TempDir(), "trace.out")
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ Attached: true,
+ }
+
+ _, err = Run(conf, args)
+ if err != nil {
+ t.Fatalf("Creating container: %v", err)
+ }
+
+ // Basic test; simply assert that the profiles are not empty.
+ for _, name := range []string{conf.ProfileBlock, conf.ProfileCPU, conf.ProfileHeap, conf.ProfileMutex, conf.TraceFile} {
+ fi, err := os.Stat(name)
+ if err != nil {
+ t.Fatalf("Unable to stat profile file %s: %v", name, err)
+ }
+ if fi.Size() == 0 {
+ t.Errorf("Profile file %s is empty: %+v", name, fi)
+ }
+ }
+}
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index 9d8022e50..a620944a3 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -1624,6 +1624,7 @@ func TestMultiContainerGoferKilled(t *testing.T) {
defer cleanup()
conf := testutil.TestConfig(t)
+ conf.VFS2 = true
conf.RootDir = rootDir
sleep := []string{"sleep", "100"}
diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD
index 3280b74fe..8d5a6d300 100644
--- a/runsc/fsgofer/BUILD
+++ b/runsc/fsgofer/BUILD
@@ -9,12 +9,16 @@ go_library(
"fsgofer_amd64_unsafe.go",
"fsgofer_arm64_unsafe.go",
"fsgofer_unsafe.go",
+ "lisafs.go",
],
visibility = ["//runsc:__subpackages__"],
deps = [
+ "//pkg/abi/linux",
"//pkg/cleanup",
"//pkg/fd",
+ "//pkg/lisafs",
"//pkg/log",
+ "//pkg/marshal/primitive",
"//pkg/p9",
"//pkg/sync",
"//pkg/syserr",
@@ -37,3 +41,15 @@ go_test(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "lisafs_test",
+ size = "small",
+ srcs = ["lisafs_test.go"],
+ deps = [
+ ":fsgofer",
+ "//pkg/lisafs",
+ "//pkg/lisafs/testsuite",
+ "//pkg/log",
+ ],
+)
diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go
index ee6cc97df..6cdd6d695 100644
--- a/runsc/fsgofer/fsgofer_test.go
+++ b/runsc/fsgofer/fsgofer_test.go
@@ -105,14 +105,14 @@ func testReadWrite(f p9.File, flags p9.OpenFlags, content []byte) error {
return nil
}
-type state struct {
+type fileState struct {
root *localFile
file *localFile
conf Config
fileType uint32
}
-func (s state) String() string {
+func (s fileState) String() string {
return fmt.Sprintf("type(%v)", s.fileType)
}
@@ -129,11 +129,11 @@ func typeName(fileType uint32) string {
}
}
-func runAll(t *testing.T, test func(*testing.T, state)) {
+func runAll(t *testing.T, test func(*testing.T, fileState)) {
runCustom(t, allTypes, allConfs, test)
}
-func runCustom(t *testing.T, types []uint32, confs []Config, test func(*testing.T, state)) {
+func runCustom(t *testing.T, types []uint32, confs []Config, test func(*testing.T, fileState)) {
for _, c := range confs {
for _, ft := range types {
name := fmt.Sprintf("%s/%s", configTestName(&c), typeName(ft))
@@ -159,7 +159,7 @@ func runCustom(t *testing.T, types []uint32, confs []Config, test func(*testing.
t.Fatalf("root.Walk({%q}) failed, err: %v", "symlink", err)
}
- st := state{
+ st := fileState{
root: root.(*localFile),
file: file.(*localFile),
conf: c,
@@ -227,7 +227,7 @@ func createFile(dir *localFile, name string) (*localFile, error) {
}
func TestReadWrite(t *testing.T) {
- runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s fileState) {
child, err := createFile(s.file, "test")
if err != nil {
t.Fatalf("%v: createFile() failed, err: %v", s, err)
@@ -261,7 +261,7 @@ func TestReadWrite(t *testing.T) {
}
func TestCreate(t *testing.T) {
- runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s fileState) {
for i, flags := range allOpenFlags {
_, l, _, _, err := s.file.Create(fmt.Sprintf("test-%d", i), flags, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
if err != nil {
@@ -296,7 +296,7 @@ func TestCreateSetGID(t *testing.T) {
t.Skipf("Test requires CAP_CHOWN")
}
- runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s fileState) {
// Change group and set setgid to the parent dir.
if err := unix.Chown(s.file.hostPath, os.Getuid(), nobody); err != nil {
t.Fatalf("Chown() failed: %v", err)
@@ -364,7 +364,7 @@ func TestCreateSetGID(t *testing.T) {
// TestReadWriteDup tests that a file opened in any mode can be dup'ed and
// reopened in any other mode.
func TestReadWriteDup(t *testing.T) {
- runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s fileState) {
child, err := createFile(s.file, "test")
if err != nil {
t.Fatalf("%v: createFile() failed, err: %v", s, err)
@@ -410,7 +410,7 @@ func TestReadWriteDup(t *testing.T) {
}
func TestUnopened(t *testing.T) {
- runCustom(t, []uint32{unix.S_IFREG}, allConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{unix.S_IFREG}, allConfs, func(t *testing.T, s fileState) {
b := []byte("foobar")
if _, err := s.file.WriteAt(b, 0); err != unix.EBADF {
t.Errorf("%v: WriteAt() should have failed, got: %v, expected: unix.EBADF", s, err)
@@ -432,7 +432,7 @@ func TestUnopened(t *testing.T) {
// was open with O_PATH, but Open() was not checking for it and allowing the
// control file to be reused.
func TestOpenOPath(t *testing.T) {
- runCustom(t, []uint32{unix.S_IFREG}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{unix.S_IFREG}, rwConfs, func(t *testing.T, s fileState) {
// Fist remove all permissions on the file.
if err := s.file.SetAttr(p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(0)}); err != nil {
t.Fatalf("SetAttr(): %v", err)
@@ -465,7 +465,7 @@ func SetGetAttr(l *localFile, valid p9.SetAttrMask, attr p9.SetAttr) (p9.Attr, e
}
func TestSetAttrPerm(t *testing.T) {
- runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, allTypes, rwConfs, func(t *testing.T, s fileState) {
valid := p9.SetAttrMask{Permissions: true}
attr := p9.SetAttr{Permissions: 0777}
got, err := SetGetAttr(s.file, valid, attr)
@@ -485,7 +485,7 @@ func TestSetAttrPerm(t *testing.T) {
}
func TestSetAttrSize(t *testing.T) {
- runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, allTypes, rwConfs, func(t *testing.T, s fileState) {
for _, size := range []uint64{1024, 0, 1024 * 1024} {
valid := p9.SetAttrMask{Size: true}
attr := p9.SetAttr{Size: size}
@@ -508,7 +508,7 @@ func TestSetAttrSize(t *testing.T) {
}
func TestSetAttrTime(t *testing.T) {
- runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, allTypes, rwConfs, func(t *testing.T, s fileState) {
valid := p9.SetAttrMask{ATime: true, ATimeNotSystemTime: true}
attr := p9.SetAttr{ATimeSeconds: 123, ATimeNanoSeconds: 456}
got, err := SetGetAttr(s.file, valid, attr)
@@ -542,7 +542,7 @@ func TestSetAttrOwner(t *testing.T) {
t.Skipf("SetAttr(owner) test requires CAP_CHOWN, running as %d", os.Getuid())
}
- runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, allTypes, rwConfs, func(t *testing.T, s fileState) {
newUID := os.Getuid() + 1
valid := p9.SetAttrMask{UID: true}
attr := p9.SetAttr{UID: p9.UID(newUID)}
@@ -571,7 +571,7 @@ func SetGetXattr(l *localFile, name string, value string) error {
}
func TestSetGetDisabledXattr(t *testing.T) {
- runCustom(t, []uint32{unix.S_IFREG}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{unix.S_IFREG}, rwConfs, func(t *testing.T, s fileState) {
name := "user.merkle.offset"
value := "tmp"
err := SetGetXattr(s.file, name, value)
@@ -582,7 +582,7 @@ func TestSetGetDisabledXattr(t *testing.T) {
}
func TestSetGetXattr(t *testing.T) {
- runCustom(t, []uint32{unix.S_IFREG}, []Config{{ROMount: false, EnableVerityXattr: true}}, func(t *testing.T, s state) {
+ runCustom(t, []uint32{unix.S_IFREG}, []Config{{ROMount: false, EnableVerityXattr: true}}, func(t *testing.T, s fileState) {
name := "user.merkle.offset"
value := "tmp"
err := SetGetXattr(s.file, name, value)
@@ -596,7 +596,7 @@ func TestLink(t *testing.T) {
if !specutils.HasCapabilities(capability.CAP_DAC_READ_SEARCH) {
t.Skipf("Link test requires CAP_DAC_READ_SEARCH, running as %d", os.Getuid())
}
- runCustom(t, allTypes, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, allTypes, rwConfs, func(t *testing.T, s fileState) {
const dirName = "linkdir"
const linkFile = "link"
if _, err := s.root.Mkdir(dirName, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
@@ -625,7 +625,7 @@ func TestROMountChecks(t *testing.T) {
uid := p9.UID(os.Getuid())
gid := p9.GID(os.Getgid())
- runCustom(t, allTypes, roConfs, func(t *testing.T, s state) {
+ runCustom(t, allTypes, roConfs, func(t *testing.T, s fileState) {
if s.fileType != unix.S_IFLNK {
if _, _, _, err := s.file.Open(p9.WriteOnly); err != want {
t.Errorf("Open() should have failed, got: %v, expected: %v", err, want)
@@ -676,7 +676,7 @@ func TestROMountChecks(t *testing.T) {
}
func TestWalkNotFound(t *testing.T) {
- runCustom(t, []uint32{unix.S_IFDIR}, allConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{unix.S_IFDIR}, allConfs, func(t *testing.T, s fileState) {
if _, _, err := s.file.Walk([]string{"nobody-here"}); err != unix.ENOENT {
t.Errorf("Walk(%q) should have failed, got: %v, expected: unix.ENOENT", "nobody-here", err)
}
@@ -695,7 +695,7 @@ func TestWalkNotFound(t *testing.T) {
}
func TestWalkDup(t *testing.T) {
- runAll(t, func(t *testing.T, s state) {
+ runAll(t, func(t *testing.T, s fileState) {
_, dup, err := s.file.Walk([]string{})
if err != nil {
t.Fatalf("%v: Walk(nil) failed, err: %v", s, err)
@@ -708,7 +708,7 @@ func TestWalkDup(t *testing.T) {
}
func TestWalkMultiple(t *testing.T) {
- runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s fileState) {
var names []string
var parent p9.File = s.file
for i := 0; i < 5; i++ {
@@ -729,7 +729,7 @@ func TestWalkMultiple(t *testing.T) {
}
func TestReaddir(t *testing.T) {
- runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s fileState) {
name := "dir"
if _, err := s.file.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
t.Fatalf("%v: MkDir(%s) failed, err: %v", s, name, err)
@@ -915,7 +915,7 @@ func TestDoubleAttachError(t *testing.T) {
}
func TestTruncate(t *testing.T) {
- runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s fileState) {
child, err := createFile(s.file, "test")
if err != nil {
t.Fatalf("createFile() failed: %v", err)
@@ -951,7 +951,7 @@ func TestTruncate(t *testing.T) {
}
func TestMknod(t *testing.T) {
- runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s fileState) {
_, err := s.file.Mknod("test", p9.ModeRegular|0777, 1, 2, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
if err != nil {
t.Fatalf("Mknod() failed: %v", err)
diff --git a/runsc/fsgofer/fsgofer_unsafe.go b/runsc/fsgofer/fsgofer_unsafe.go
index f11fea40d..fb4fbe0d2 100644
--- a/runsc/fsgofer/fsgofer_unsafe.go
+++ b/runsc/fsgofer/fsgofer_unsafe.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
)
+var unixDirentMaxSize uint32 = uint32(unsafe.Sizeof(unix.Dirent{}))
+
func utimensat(dirFd int, name string, times [2]unix.Timespec, flags int) error {
// utimensat(2) doesn't accept empty name, instead name must be nil to make it
// operate directly on 'dirFd' unlike other *at syscalls.
@@ -80,3 +82,31 @@ func renameat(oldDirFD int, oldName string, newDirFD int, newName string) error
}
return nil
}
+
+func parseDirents(buf []byte, handleDirent func(ino uint64, off int64, ftype uint8, name string) bool) {
+ for len(buf) > 0 {
+ // Interpret the buf populated by unix.Getdents as unix.Dirent.
+ dirent := *(*unix.Dirent)(unsafe.Pointer(&buf[0]))
+
+ // Extracting the name is pretty tedious...
+ var nameBuf [unix.NAME_MAX]byte
+ var nameLen int
+ for i := 0; i < len(dirent.Name); i++ {
+ // The name is null terminated.
+ if dirent.Name[i] == 0 {
+ nameLen = i
+ break
+ }
+ nameBuf[i] = byte(dirent.Name[i])
+ }
+ name := string(nameBuf[:nameLen])
+
+ // Deliver results to caller.
+ if !handleDirent(dirent.Ino, dirent.Off, dirent.Type, name) {
+ return
+ }
+
+ // Advance buf for the next dirent.
+ buf = buf[dirent.Reclen:]
+ }
+}
diff --git a/runsc/fsgofer/lisafs.go b/runsc/fsgofer/lisafs.go
new file mode 100644
index 000000000..0db44ff6a
--- /dev/null
+++ b/runsc/fsgofer/lisafs.go
@@ -0,0 +1,1084 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsgofer
+
+import (
+ "io"
+ "math"
+ "path"
+ "strconv"
+ "sync/atomic"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ rwfd "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/lisafs"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
+ "gvisor.dev/gvisor/pkg/p9"
+)
+
+// LisafsServer implements lisafs.ServerImpl for fsgofer.
+type LisafsServer struct {
+ lisafs.Server
+ config Config
+}
+
+var _ lisafs.ServerImpl = (*LisafsServer)(nil)
+
+// NewLisafsServer initializes a new lisafs server for fsgofer.
+func NewLisafsServer(config Config) *LisafsServer {
+ s := &LisafsServer{config: config}
+ s.Server.Init(s)
+ return s
+}
+
+// Mount implements lisafs.ServerImpl.Mount.
+func (s *LisafsServer) Mount(c *lisafs.Connection, mountPath string) (lisafs.ControlFDImpl, lisafs.Inode, error) {
+ s.RenameMu.RLock()
+ defer s.RenameMu.RUnlock()
+
+ rootFD, rootStat, err := tryStepLocked(c, mountPath, nil, func(flags int) (int, error) {
+ return unix.Open(mountPath, flags, 0)
+ })
+ if err != nil {
+ return nil, lisafs.Inode{}, err
+ }
+
+ var rootIno lisafs.Inode
+ rootFD.initInodeWithStat(&rootIno, &rootStat)
+ return rootFD, rootIno, nil
+}
+
+// MaxMessageSize implements lisafs.ServerImpl.MaxMessageSize.
+func (s *LisafsServer) MaxMessageSize() uint32 {
+ return lisafs.MaxMessageSize()
+}
+
+// SupportedMessages implements lisafs.ServerImpl.SupportedMessages.
+func (s *LisafsServer) SupportedMessages() []lisafs.MID {
+ // Note that Flush, FListXattr and FRemoveXattr are not supported.
+ return []lisafs.MID{
+ lisafs.Mount,
+ lisafs.Channel,
+ lisafs.FStat,
+ lisafs.SetStat,
+ lisafs.Walk,
+ lisafs.WalkStat,
+ lisafs.OpenAt,
+ lisafs.OpenCreateAt,
+ lisafs.Close,
+ lisafs.FSync,
+ lisafs.PWrite,
+ lisafs.PRead,
+ lisafs.MkdirAt,
+ lisafs.MknodAt,
+ lisafs.SymlinkAt,
+ lisafs.LinkAt,
+ lisafs.FStatFS,
+ lisafs.FAllocate,
+ lisafs.ReadLinkAt,
+ lisafs.Connect,
+ lisafs.UnlinkAt,
+ lisafs.RenameAt,
+ lisafs.Getdents64,
+ lisafs.FGetXattr,
+ lisafs.FSetXattr,
+ }
+}
+
+// controlFDLisa implements lisafs.ControlFDImpl.
+type controlFDLisa struct {
+ lisafs.ControlFD
+
+ // hostFD is the file descriptor which can be used to make host syscalls.
+ hostFD int
+
+ // writableHostFD is the file descriptor number for a writable FD opened on the
+ // same FD as `hostFD`. writableHostFD must only be accessed using atomic
+ // operations. It is initialized to -1, and can change in value exactly once.
+ writableHostFD int32
+}
+
+var _ lisafs.ControlFDImpl = (*controlFDLisa)(nil)
+
+// Precondition: server's rename mutex must be at least read locked.
+func newControlFDLisaLocked(c *lisafs.Connection, hostFD int, parent *controlFDLisa, name string, mode linux.FileMode) *controlFDLisa {
+ fd := &controlFDLisa{
+ hostFD: hostFD,
+ writableHostFD: -1,
+ }
+ fd.ControlFD.Init(c, parent.FD(), name, mode, fd)
+ return fd
+}
+
+func (fd *controlFDLisa) initInode(inode *lisafs.Inode) error {
+ inode.ControlFD = fd.ID()
+ return fstatTo(fd.hostFD, &inode.Stat)
+}
+
+func (fd *controlFDLisa) initInodeWithStat(inode *lisafs.Inode, unixStat *unix.Stat_t) {
+ inode.ControlFD = fd.ID()
+ unixToLinuxStat(unixStat, &inode.Stat)
+}
+
+func (fd *controlFDLisa) getWritableFD() (int, error) {
+ if writableFD := atomic.LoadInt32(&fd.writableHostFD); writableFD != -1 {
+ return int(writableFD), nil
+ }
+
+ writableFD, err := unix.Openat(int(procSelfFD.FD()), strconv.Itoa(fd.hostFD), (unix.O_WRONLY|openFlags)&^unix.O_NOFOLLOW, 0)
+ if err != nil {
+ return -1, err
+ }
+ if !atomic.CompareAndSwapInt32(&fd.writableHostFD, -1, int32(writableFD)) {
+ // Race detected, use the new value and clean this up.
+ unix.Close(writableFD)
+ return int(atomic.LoadInt32(&fd.writableHostFD)), nil
+ }
+ return writableFD, nil
+}
+
+// FD implements lisafs.ControlFDImpl.FD.
+func (fd *controlFDLisa) FD() *lisafs.ControlFD {
+ if fd == nil {
+ return nil
+ }
+ return &fd.ControlFD
+}
+
+// Close implements lisafs.ControlFDImpl.Close.
+func (fd *controlFDLisa) Close(c *lisafs.Connection) {
+ if fd.hostFD >= 0 {
+ _ = unix.Close(fd.hostFD)
+ fd.hostFD = -1
+ }
+ // No concurrent access is possible so no need to use atomics.
+ if fd.writableHostFD >= 0 {
+ _ = unix.Close(int(fd.writableHostFD))
+ fd.writableHostFD = -1
+ }
+}
+
+// Stat implements lisafs.ControlFDImpl.Stat.
+func (fd *controlFDLisa) Stat(c *lisafs.Connection, comm lisafs.Communicator) (uint32, error) {
+ var resp linux.Statx
+ if err := fstatTo(fd.hostFD, &resp); err != nil {
+ return 0, err
+ }
+
+ respLen := uint32(resp.SizeBytes())
+ resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+ return respLen, nil
+}
+
+// SetStat implements lisafs.ControlFDImpl.SetStat.
+func (fd *controlFDLisa) SetStat(c *lisafs.Connection, comm lisafs.Communicator, stat lisafs.SetStatReq) (uint32, error) {
+ var resp lisafs.SetStatResp
+ if stat.Mask&unix.STATX_MODE != 0 {
+ if err := unix.Fchmod(fd.hostFD, stat.Mode&^unix.S_IFMT); err != nil {
+ log.Debugf("SetStat fchmod failed %q, err: %v", fd.FilePath(), err)
+ resp.FailureMask |= unix.STATX_MODE
+ resp.FailureErrNo = uint32(p9.ExtractErrno(err))
+ }
+ }
+
+ if stat.Mask&unix.STATX_SIZE != 0 {
+ // ftruncate(2) requires the FD to be open for writing.
+ writableFD, err := fd.getWritableFD()
+ if err == nil {
+ err = unix.Ftruncate(writableFD, int64(stat.Size))
+ }
+ if err != nil {
+ log.Debugf("SetStat ftruncate failed %q, err: %v", fd.FilePath(), err)
+ resp.FailureMask |= unix.STATX_SIZE
+ resp.FailureErrNo = uint32(p9.ExtractErrno(err))
+ }
+ }
+
+ if stat.Mask&(unix.STATX_ATIME|unix.STATX_MTIME) != 0 {
+ utimes := [2]unix.Timespec{
+ {Sec: 0, Nsec: unix.UTIME_OMIT},
+ {Sec: 0, Nsec: unix.UTIME_OMIT},
+ }
+ if stat.Mask&unix.STATX_ATIME != 0 {
+ utimes[0].Sec = stat.Atime.Sec
+ utimes[0].Nsec = stat.Atime.Nsec
+ }
+ if stat.Mask&unix.STATX_MTIME != 0 {
+ utimes[1].Sec = stat.Mtime.Sec
+ utimes[1].Nsec = stat.Mtime.Nsec
+ }
+
+ if fd.IsSymlink() {
+			// utimensat operates differently than other syscalls. To operate on a
+ // symlink it *requires* AT_SYMLINK_NOFOLLOW with dirFD and a non-empty
+ // name.
+ c.Server().WithRenameReadLock(func() error {
+ if err := utimensat(fd.ParentLocked().(*controlFDLisa).hostFD, fd.NameLocked(), utimes, unix.AT_SYMLINK_NOFOLLOW); err != nil {
+ log.Debugf("SetStat utimens failed %q, err: %v", fd.FilePathLocked(), err)
+ resp.FailureMask |= (stat.Mask & (unix.STATX_ATIME | unix.STATX_MTIME))
+ resp.FailureErrNo = uint32(p9.ExtractErrno(err))
+ }
+ return nil
+ })
+ } else {
+ hostFD := fd.hostFD
+ if fd.IsRegular() {
+ // For regular files, utimensat(2) requires the FD to be open for
+ // writing, see BUGS section.
+ writableFD, err := fd.getWritableFD()
+ if err != nil {
+ return 0, err
+ }
+ hostFD = writableFD
+ }
+ // Directories and regular files can operate directly on the fd
+ // using empty name.
+ err := utimensat(hostFD, "", utimes, 0)
+ if err != nil {
+ log.Debugf("SetStat utimens failed %q, err: %v", fd.FilePath(), err)
+ resp.FailureMask |= (stat.Mask & (unix.STATX_ATIME | unix.STATX_MTIME))
+ resp.FailureErrNo = uint32(p9.ExtractErrno(err))
+ }
+ }
+ }
+
+ if stat.Mask&(unix.STATX_UID|unix.STATX_GID) != 0 {
+ // "If the owner or group is specified as -1, then that ID is not changed"
+ // - chown(2)
+ uid := -1
+ if stat.Mask&unix.STATX_UID != 0 {
+ uid = int(stat.UID)
+ }
+ gid := -1
+ if stat.Mask&unix.STATX_GID != 0 {
+ gid = int(stat.GID)
+ }
+ if err := unix.Fchownat(fd.hostFD, "", uid, gid, unix.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW); err != nil {
+ log.Debugf("SetStat fchown failed %q, err: %v", fd.FilePath(), err)
+ resp.FailureMask |= stat.Mask & (unix.STATX_UID | unix.STATX_GID)
+ resp.FailureErrNo = uint32(p9.ExtractErrno(err))
+ }
+ }
+
+ respLen := uint32(resp.SizeBytes())
+ resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+ return respLen, nil
+}
+
+// Walk implements lisafs.ControlFDImpl.Walk.
+func (fd *controlFDLisa) Walk(c *lisafs.Connection, comm lisafs.Communicator, path lisafs.StringArray) (uint32, error) {
+ // We need to generate inodes for each component walked. We will manually
+ // marshal the inodes into the payload buffer as they are generated to avoid
+ // the slice allocation. The memory format should be lisafs.WalkResp's.
+ var numInodes primitive.Uint32
+ var status lisafs.WalkStatus
+ maxPayloadSize := status.SizeBytes() + numInodes.SizeBytes() + (len(path) * (*lisafs.Inode)(nil).SizeBytes())
+ if maxPayloadSize > math.MaxUint32 {
+ // Too much to walk, can't do.
+ return 0, unix.EIO
+ }
+ payloadBuf := comm.PayloadBuf(uint32(maxPayloadSize))
+ payloadPos := status.SizeBytes() + numInodes.SizeBytes()
+
+ s := c.Server()
+ s.RenameMu.RLock()
+ defer s.RenameMu.RUnlock()
+
+ curDirFD := fd
+ cu := cleanup.Make(func() {
+ // Destroy all newly created FDs until now. Walk upward from curDirFD to
+ // fd. Do not destroy fd as the client still owns that.
+ for curDirFD != fd {
+ c.RemoveControlFDLocked(curDirFD.ID())
+ curDirFD = curDirFD.ParentLocked().(*controlFDLisa)
+ }
+ })
+ defer cu.Clean()
+
+ for _, name := range path {
+		// Symlinks terminate walk. The client gets the symlink inode, but will
+ // have to invoke Walk again with the resolved path.
+ if curDirFD.IsSymlink() {
+ status = lisafs.WalkComponentSymlink
+ break
+ }
+
+ child, childStat, err := tryStepLocked(c, name, curDirFD, func(flags int) (int, error) {
+ return unix.Openat(curDirFD.hostFD, name, flags, 0)
+ })
+ if err == unix.ENOENT {
+ status = lisafs.WalkComponentDoesNotExist
+ break
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ // Write inode to payloadBuf and update state.
+ var childInode lisafs.Inode
+ child.initInodeWithStat(&childInode, &childStat)
+ childInode.MarshalUnsafe(payloadBuf[payloadPos:])
+ payloadPos += childInode.SizeBytes()
+ numInodes++
+ curDirFD = child
+ }
+ cu.Release()
+
+ // lisafs.WalkResp writes the walk status followed by the number of inodes in
+ // the beginning.
+ status.MarshalUnsafe(payloadBuf)
+ numInodes.MarshalUnsafe(payloadBuf[status.SizeBytes():])
+ return uint32(payloadPos), nil
+}
+
+// WalkStat implements lisafs.ControlFDImpl.WalkStat.
+func (fd *controlFDLisa) WalkStat(c *lisafs.Connection, comm lisafs.Communicator, path lisafs.StringArray) (uint32, error) {
+ // We may need to generate statx for dirFD + each component walked. We will
+ // manually marshal the statx results into the payload buffer as they are
+ // generated to avoid the slice allocation. The memory format should be the
+ // same as lisafs.WalkStatResp's.
+ var numStats primitive.Uint32
+ maxPayloadSize := numStats.SizeBytes() + (len(path) * linux.SizeOfStatx)
+ if maxPayloadSize > math.MaxUint32 {
+ // Too much to walk, can't do.
+ return 0, unix.EIO
+ }
+ payloadBuf := comm.PayloadBuf(uint32(maxPayloadSize))
+ payloadPos := numStats.SizeBytes()
+
+ s := c.Server()
+ s.RenameMu.RLock()
+ defer s.RenameMu.RUnlock()
+
+ curDirFD := fd.hostFD
+ closeCurDirFD := func() {
+ if curDirFD != fd.hostFD {
+ unix.Close(curDirFD)
+ }
+ }
+ defer closeCurDirFD()
+ var (
+ stat linux.Statx
+ unixStat unix.Stat_t
+ )
+ if len(path) > 0 && len(path[0]) == 0 {
+ // Write stat results for dirFD if the first path component is "".
+ if err := unix.Fstat(fd.hostFD, &unixStat); err != nil {
+ return 0, err
+ }
+ unixToLinuxStat(&unixStat, &stat)
+ stat.MarshalUnsafe(payloadBuf[payloadPos:])
+ payloadPos += stat.SizeBytes()
+ path = path[1:]
+ numStats++
+ }
+
+ // Don't attempt walking if parent is a symlink.
+ if fd.IsSymlink() {
+ return 0, nil
+ }
+ for _, name := range path {
+ curFD, err := unix.Openat(curDirFD, name, unix.O_PATH|openFlags, 0)
+ if err == unix.ENOENT {
+ // No more path components exist on the filesystem. Return the partial
+ // walk to the client.
+ break
+ }
+ if err != nil {
+ return 0, err
+ }
+ closeCurDirFD()
+ curDirFD = curFD
+
+ // Write stat results for curFD.
+ if err := unix.Fstat(curFD, &unixStat); err != nil {
+ return 0, err
+ }
+ unixToLinuxStat(&unixStat, &stat)
+ stat.MarshalUnsafe(payloadBuf[payloadPos:])
+ payloadPos += stat.SizeBytes()
+ numStats++
+
+		// Symlinks terminate walk. The client gets the symlink stat result, but
+ // will have to invoke Walk again with the resolved path.
+ if unixStat.Mode&unix.S_IFMT == unix.S_IFLNK {
+ break
+ }
+ }
+
+ // lisafs.WalkStatResp writes the number of stats in the beginning.
+ numStats.MarshalUnsafe(payloadBuf)
+ return uint32(payloadPos), nil
+}
+
+// Open implements lisafs.ControlFDImpl.Open.
+func (fd *controlFDLisa) Open(c *lisafs.Connection, comm lisafs.Communicator, flags uint32) (uint32, error) {
+ flags |= openFlags
+ newHostFD, err := unix.Openat(int(procSelfFD.FD()), strconv.Itoa(fd.hostFD), int(flags)&^unix.O_NOFOLLOW, 0)
+ if err != nil {
+ return 0, err
+ }
+ newFD := fd.newOpenFDLisa(newHostFD, flags)
+
+ if fd.IsRegular() {
+ // Donate FD for regular files only. Since FD donation is a destructive
+ // operation, we should duplicate the to-be-donated FD. Eat the error if
+ // one occurs, it is better to have an FD without a host FD, than failing
+ // the Open attempt.
+ if dupFD, err := unix.Dup(newFD.hostFD); err == nil {
+ _ = comm.DonateFD(dupFD)
+ }
+ }
+
+ resp := lisafs.OpenAtResp{NewFD: newFD.ID()}
+ respLen := uint32(resp.SizeBytes())
+ resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+ return respLen, nil
+}
+
+// OpenCreate implements lisafs.ControlFDImpl.OpenCreate.
+func (fd *controlFDLisa) OpenCreate(c *lisafs.Connection, comm lisafs.Communicator, mode linux.FileMode, uid lisafs.UID, gid lisafs.GID, name string, flags uint32) (uint32, error) {
+	// Need to hold rename mutex for reading while creating the file. Also keep
+	// holding it while the cleanup is still possible.
+	var resp lisafs.OpenCreateAtResp
+	var newFD *openFDLisa
+	if err := c.Server().WithRenameReadLock(func() error {
+		createFlags := unix.O_CREAT | unix.O_EXCL | unix.O_RDONLY | unix.O_NONBLOCK | openFlags
+		childHostFD, err := unix.Openat(fd.hostFD, name, createFlags, uint32(mode&^linux.FileTypeMask))
+		if err != nil {
+			return err
+		}
+
+		childFD := newControlFDLisaLocked(c, childHostFD, fd, name, linux.ModeRegular)
+		cu := cleanup.Make(func() {
+			// Best effort attempt to remove the file in case of failure.
+			if err := unix.Unlinkat(fd.hostFD, name, 0); err != nil {
+				log.Warningf("error unlinking file %q after failure: %v", path.Join(fd.FilePathLocked(), name), err)
+			}
+			c.RemoveControlFDLocked(childFD.ID())
+		})
+		defer cu.Clean()
+
+		// Set the owners as requested by the client.
+		if err := unix.Fchownat(childFD.hostFD, "", int(uid), int(gid), unix.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW); err != nil {
+			log.Debugf("OpenCreateAt fchownat failed, err: %v", err)
+			return err
+		}
+
+		// Do not use the stat result from tryOpen because the owners might have
+		// changed. initInode() will stat the FD again and use fresh results.
+		if err := childFD.initInode(&resp.Child); err != nil {
+			log.Debugf("OpenCreateAt initInode failed, err: %v", err)
+			return err
+		}
+
+		// Now open an FD to the newly created file with the flags requested by the client.
+		flags |= openFlags
+		newHostFD, err := unix.Openat(int(procSelfFD.FD()), strconv.Itoa(childFD.hostFD), int(flags)&^unix.O_NOFOLLOW, 0)
+		if err != nil {
+			log.Debugf("OpenCreateAt openat failed, err: %v", err)
+			return err
+		}
+		cu.Release()
+
+		newFD = childFD.newOpenFDLisa(newHostFD, uint32(flags))
+		resp.NewFD = newFD.ID()
+		return nil
+	}); err != nil {
+		return 0, err
+	}
+
+	// Donate FD because open(O_CREAT|O_EXCL) always creates a regular file.
+	// Since FD donation is a destructive operation, we should duplicate the
+	// to-be-donated FD. Eat the error if one occurs, it is better to have an FD
+	// without a host FD, than failing the Open attempt.
+	if dupFD, err := unix.Dup(newFD.hostFD); err == nil {
+		_ = comm.DonateFD(dupFD)
+	}
+
+	respLen := uint32(resp.SizeBytes())
+	resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+	return respLen, nil
+}
+
// Mkdir implements lisafs.ControlFDImpl.Mkdir.
//
// It creates directory name under fd (which must be a directory control FD)
// with the given mode, chowns it to (uid, gid), and returns the marshalled
// lisafs.MkdirAtResp (the new child's inode) in comm's payload buffer. If any
// step after the mkdir fails, the new directory is removed on a best-effort
// basis. The whole operation runs under the server's rename read lock so the
// directory tree cannot be mutated concurrently by a rename.
func (fd *controlFDLisa) Mkdir(c *lisafs.Connection, comm lisafs.Communicator, mode linux.FileMode, uid lisafs.UID, gid lisafs.GID, name string) (uint32, error) {
	var resp lisafs.MkdirAtResp
	if err := c.Server().WithRenameReadLock(func() error {
		// Strip the file type bits; mkdirat(2) only takes permission bits.
		if err := unix.Mkdirat(fd.hostFD, name, uint32(mode&^linux.FileTypeMask)); err != nil {
			return err
		}
		cu := cleanup.Make(func() {
			// Best effort attempt to remove the dir in case of failure.
			if err := unix.Unlinkat(fd.hostFD, name, unix.AT_REMOVEDIR); err != nil {
				log.Warningf("error unlinking dir %q after failure: %v", path.Join(fd.FilePathLocked(), name), err)
			}
		})
		defer cu.Clean()

		// Open directory to change ownership.
		childDirFd, err := unix.Openat(fd.hostFD, name, unix.O_DIRECTORY|unix.O_RDONLY|openFlags, 0)
		if err != nil {
			return err
		}
		// Empty path + AT_EMPTY_PATH chowns the FD itself.
		if err := unix.Fchownat(childDirFd, "", int(uid), int(gid), unix.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW); err != nil {
			unix.Close(childDirFd)
			return err
		}

		// childDir takes ownership of childDirFd from here on.
		childDir := newControlFDLisaLocked(c, childDirFd, fd, name, linux.ModeDirectory)
		if err := childDir.initInode(&resp.ChildDir); err != nil {
			c.RemoveControlFDLocked(childDir.ID())
			return err
		}
		cu.Release()

		return nil
	}); err != nil {
		return 0, err
	}

	respLen := uint32(resp.SizeBytes())
	resp.MarshalUnsafe(comm.PayloadBuf(respLen))
	return respLen, nil
}
+
// Mknod implements lisafs.ControlFDImpl.Mknod.
//
// Only regular files may be created; other node types (devices, FIFOs,
// sockets) are rejected with EPERM. On success the marshalled
// lisafs.MknodAtResp (the new child's inode) is written to comm's payload
// buffer. If ownership setup fails after the mknod, the new file is removed
// on a best-effort basis.
func (fd *controlFDLisa) Mknod(c *lisafs.Connection, comm lisafs.Communicator, mode linux.FileMode, uid lisafs.UID, gid lisafs.GID, name string, minor uint32, major uint32) (uint32, error) {
	// From mknod(2) man page:
	// "EPERM: [...] if the filesystem containing pathname does not support
	// the type of node requested."
	if mode.FileType() != linux.ModeRegular {
		return 0, unix.EPERM
	}

	var resp lisafs.MknodAtResp
	if err := c.Server().WithRenameReadLock(func() error {
		// dev is 0 because only regular files are allowed here.
		if err := unix.Mknodat(fd.hostFD, name, uint32(mode), 0); err != nil {
			return err
		}
		cu := cleanup.Make(func() {
			// Best effort attempt to remove the file in case of failure.
			if err := unix.Unlinkat(fd.hostFD, name, 0); err != nil {
				log.Warningf("error unlinking file %q after failure: %v", path.Join(fd.FilePathLocked(), name), err)
			}
		})
		defer cu.Clean()

		// Open file to change ownership.
		childFD, err := unix.Openat(fd.hostFD, name, unix.O_PATH|openFlags, 0)
		if err != nil {
			return err
		}
		// Empty path + AT_EMPTY_PATH chowns the FD itself.
		if err := unix.Fchownat(childFD, "", int(uid), int(gid), unix.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW); err != nil {
			unix.Close(childFD)
			return err
		}

		// child takes ownership of childFD from here on.
		child := newControlFDLisaLocked(c, childFD, fd, name, mode)
		if err := child.initInode(&resp.Child); err != nil {
			c.RemoveControlFDLocked(child.ID())
			return err
		}
		cu.Release()
		return nil
	}); err != nil {
		return 0, err
	}

	respLen := uint32(resp.SizeBytes())
	resp.MarshalUnsafe(comm.PayloadBuf(respLen))
	return respLen, nil
}
+
// Symlink implements lisafs.ControlFDImpl.Symlink.
//
// It creates a symlink named name under fd pointing at target, chowns it to
// (uid, gid) via an O_PATH FD (so the link itself, not its target, is
// affected), and returns the marshalled lisafs.SymlinkAtResp in comm's
// payload buffer. On failure after creation, the symlink is unlinked on a
// best-effort basis.
func (fd *controlFDLisa) Symlink(c *lisafs.Connection, comm lisafs.Communicator, name string, target string, uid lisafs.UID, gid lisafs.GID) (uint32, error) {
	var resp lisafs.SymlinkAtResp
	if err := c.Server().WithRenameReadLock(func() error {
		if err := unix.Symlinkat(target, fd.hostFD, name); err != nil {
			return err
		}
		cu := cleanup.Make(func() {
			// Best effort attempt to remove the symlink in case of failure.
			if err := unix.Unlinkat(fd.hostFD, name, 0); err != nil {
				log.Warningf("error unlinking file %q after failure: %v", path.Join(fd.FilePathLocked(), name), err)
			}
		})
		defer cu.Clean()

		// Open symlink to change ownership. O_PATH is required because a
		// symlink cannot be opened for reading.
		symlinkFD, err := unix.Openat(fd.hostFD, name, unix.O_PATH|openFlags, 0)
		if err != nil {
			return err
		}
		if err := unix.Fchownat(symlinkFD, "", int(uid), int(gid), unix.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW); err != nil {
			unix.Close(symlinkFD)
			return err
		}

		// symlink takes ownership of symlinkFD from here on.
		symlink := newControlFDLisaLocked(c, symlinkFD, fd, name, linux.ModeSymlink)
		if err := symlink.initInode(&resp.Symlink); err != nil {
			c.RemoveControlFDLocked(symlink.ID())
			return err
		}
		cu.Release()
		return nil
	}); err != nil {
		return 0, err
	}

	respLen := uint32(resp.SizeBytes())
	resp.MarshalUnsafe(comm.PayloadBuf(respLen))
	return respLen, nil
}
+
// Link implements lisafs.ControlFDImpl.Link.
//
// It hard-links the file referred to by fd (via AT_EMPTY_PATH) into directory
// dir under name, opens the new link with tryStepLocked to build a control FD
// for it, and returns the marshalled lisafs.LinkAtResp in comm's payload
// buffer. If the subsequent open fails, the new link is removed on a
// best-effort basis.
func (fd *controlFDLisa) Link(c *lisafs.Connection, comm lisafs.Communicator, dir lisafs.ControlFDImpl, name string) (uint32, error) {
	var resp lisafs.LinkAtResp
	if err := c.Server().WithRenameReadLock(func() error {
		dirFD := dir.(*controlFDLisa)
		// Empty old path + AT_EMPTY_PATH links the file fd.hostFD refers to.
		if err := unix.Linkat(fd.hostFD, "", dirFD.hostFD, name, unix.AT_EMPTY_PATH); err != nil {
			return err
		}
		cu := cleanup.Make(func() {
			// Best effort attempt to remove the hard link in case of failure.
			if err := unix.Unlinkat(dirFD.hostFD, name, 0); err != nil {
				log.Warningf("error unlinking file %q after failure: %v", path.Join(dirFD.FilePathLocked(), name), err)
			}
		})
		defer cu.Clean()

		linkFD, linkStat, err := tryStepLocked(c, name, dirFD, func(flags int) (int, error) {
			return unix.Openat(dirFD.hostFD, name, flags, 0)
		})
		if err != nil {
			return err
		}
		cu.Release()

		// Reuse the stat obtained during the walk to avoid another syscall.
		linkFD.initInodeWithStat(&resp.Link, &linkStat)
		return nil
	}); err != nil {
		return 0, err
	}

	respLen := uint32(resp.SizeBytes())
	resp.MarshalUnsafe(comm.PayloadBuf(respLen))
	return respLen, nil
}
+
// StatFS implements lisafs.ControlFDImpl.StatFS.
//
// It runs fstatfs(2) on the host FD and returns the marshalled lisafs.StatFS
// (filesystem type, block/inode counts and limits) in comm's payload buffer.
func (fd *controlFDLisa) StatFS(c *lisafs.Connection, comm lisafs.Communicator) (uint32, error) {
	var s unix.Statfs_t
	if err := unix.Fstatfs(fd.hostFD, &s); err != nil {
		return 0, err
	}

	resp := lisafs.StatFS{
		Type:            uint64(s.Type),
		BlockSize:       s.Bsize,
		Blocks:          s.Blocks,
		BlocksFree:      s.Bfree,
		BlocksAvailable: s.Bavail,
		Files:           s.Files,
		FilesFree:       s.Ffree,
		NameLength:      uint64(s.Namelen),
	}
	respLen := uint32(resp.SizeBytes())
	resp.MarshalUnsafe(comm.PayloadBuf(respLen))
	return respLen, nil
}
+
// Readlink implements lisafs.ControlFDImpl.Readlink.
//
// It reads the symlink target of fd.hostFD directly into comm's payload
// buffer, retrying with a doubled buffer whenever readlinkat(2) fills the
// buffer completely (which would indicate possible truncation). Targets
// longer than 1 MiB yield ENOMEM.
func (fd *controlFDLisa) Readlink(c *lisafs.Connection, comm lisafs.Communicator) (uint32, error) {
	// We will manually marshal lisafs.ReadLinkAtResp, which just contains a
	// lisafs.SizedString. Let unix.Readlinkat directly write into the payload
	// buffer and manually write the string size before it.

	// This is similar to what os.Readlink does.
	const limit = primitive.Uint32(1024 * 1024)
	for linkLen := primitive.Uint32(128); linkLen < limit; linkLen *= 2 {
		// Payload layout: [uint32 length][target bytes].
		b := comm.PayloadBuf(uint32(linkLen) + uint32(linkLen.SizeBytes()))
		// Empty path reads the link that fd.hostFD itself refers to.
		n, err := unix.Readlinkat(fd.hostFD, "", b[linkLen.SizeBytes():])
		if err != nil {
			return 0, err
		}
		if n < int(linkLen) {
			// The target fit; record its true length and return.
			linkLen = primitive.Uint32(n)
			linkLen.MarshalUnsafe(b[:linkLen.SizeBytes()])
			return uint32(linkLen) + uint32(linkLen.SizeBytes()), nil
		}
	}
	return 0, unix.ENOMEM
}
+
// Connect implements lisafs.ControlFDImpl.Connect.
//
// It creates a host AF_UNIX socket of the requested type, donates it to the
// client via comm, and connects it to the host path that fd refers to.
// Refused (ECONNREFUSED) when host UDS support is disabled or the host path
// exceeds the sockaddr_un limit; unsupported socket types yield ENXIO.
func (fd *controlFDLisa) Connect(c *lisafs.Connection, comm lisafs.Communicator, sockType uint32) error {
	s := c.ServerImpl().(*LisafsServer)
	if !s.config.HostUDS {
		return unix.ECONNREFUSED
	}

	// Lock RenameMu so that the hostPath read stays valid and is not tampered
	// with until it is actually connected to.
	s.RenameMu.RLock()
	defer s.RenameMu.RUnlock()

	// TODO(gvisor.dev/issue/1003): Due to different app vs replacement
	// mappings, the app path may have fit in the sockaddr, but we can't fit
	// hostPath in our sockaddr. We'd need to redirect through a shorter path
	// in order to actually connect to this socket.
	hostPath := fd.FilePathLocked()
	if len(hostPath) > 108 { // UNIX_PATH_MAX = 108 is defined in afunix.h.
		return unix.ECONNREFUSED
	}

	// Only the following types are supported.
	switch sockType {
	case unix.SOCK_STREAM, unix.SOCK_DGRAM, unix.SOCK_SEQPACKET:
	default:
		return unix.ENXIO
	}

	sock, err := unix.Socket(unix.AF_UNIX, int(sockType), 0)
	if err != nil {
		return err
	}
	// DonateFD takes ownership of sock (the communicator dups/queues it).
	if err := comm.DonateFD(sock); err != nil {
		return err
	}

	sa := unix.SockaddrUnix{Name: hostPath}
	if err := unix.Connect(sock, &sa); err != nil {
		return err
	}
	return nil
}
+
+// Unlink implements lisafs.ControlFDImpl.Unlink.
+func (fd *controlFDLisa) Unlink(c *lisafs.Connection, name string, flags uint32) error {
+ return c.Server().WithRenameReadLock(func() error {
+ return unix.Unlinkat(fd.hostFD, name, int(flags))
+ })
+}
+
// RenameLocked implements lisafs.ControlFDImpl.RenameLocked.
//
// It renames fd (identified by its parent directory FD and current name) to
// newName inside newDir via renameat(2). The two returned callbacks are nil
// because no FD-local state needs updating on rename.
//
// Precondition: the server's rename mutex must be locked for writing.
func (fd *controlFDLisa) RenameLocked(c *lisafs.Connection, newDir lisafs.ControlFDImpl, newName string) (func(lisafs.ControlFDImpl), func(), error) {
	// Note that there is no controlFDLisa specific update needed on rename.
	return nil, nil, renameat(fd.ParentLocked().(*controlFDLisa).hostFD, fd.NameLocked(), newDir.(*controlFDLisa).hostFD, newName)
}
+
// GetXattr implements lisafs.ControlFDImpl.GetXattr.
//
// Only the allow-listed verity xattrs are served, and only when the server
// was configured with EnableVerityXattr; everything else yields EOPNOTSUPP.
// The response (a length-prefixed value) is marshalled by hand straight into
// comm's payload buffer to avoid an intermediate allocation and copy.
func (fd *controlFDLisa) GetXattr(c *lisafs.Connection, comm lisafs.Communicator, name string, size uint32) (uint32, error) {
	if !c.ServerImpl().(*LisafsServer).config.EnableVerityXattr {
		return 0, unix.EOPNOTSUPP
	}
	if _, ok := verityXattrs[name]; !ok {
		return 0, unix.EOPNOTSUPP
	}

	// Manually marshal lisafs.FGetXattrResp to avoid allocations and copying.
	// Payload layout: [uint32 value length][value bytes].
	var valueLen primitive.Uint32
	buf := comm.PayloadBuf(uint32(valueLen.SizeBytes()) + size)
	n, err := unix.Fgetxattr(fd.hostFD, name, buf[valueLen.SizeBytes():])
	if err != nil {
		return 0, err
	}
	valueLen = primitive.Uint32(n)
	valueLen.MarshalBytes(buf[:valueLen.SizeBytes()])

	return uint32(valueLen.SizeBytes() + n), nil
}
+
+// SetXattr implements lisafs.ControlFDImpl.SetXattr.
+func (fd *controlFDLisa) SetXattr(c *lisafs.Connection, name string, value string, flags uint32) error {
+ if !c.ServerImpl().(*LisafsServer).config.EnableVerityXattr {
+ return unix.EOPNOTSUPP
+ }
+ if _, ok := verityXattrs[name]; !ok {
+ return unix.EOPNOTSUPP
+ }
+ return unix.Fsetxattr(fd.hostFD, name, []byte(value) /* sigh */, int(flags))
+}
+
// ListXattr implements lisafs.ControlFDImpl.ListXattr.
//
// Listing xattrs is deliberately unsupported by this gofer; it always
// returns EOPNOTSUPP.
func (fd *controlFDLisa) ListXattr(c *lisafs.Connection, comm lisafs.Communicator, size uint64) (uint32, error) {
	return 0, unix.EOPNOTSUPP
}
+
// RemoveXattr implements lisafs.ControlFDImpl.RemoveXattr.
//
// Removing xattrs is deliberately unsupported by this gofer; it always
// returns EOPNOTSUPP.
func (fd *controlFDLisa) RemoveXattr(c *lisafs.Connection, comm lisafs.Communicator, name string) error {
	return unix.EOPNOTSUPP
}
+
// openFDLisa implements lisafs.OpenFDImpl. It represents a file that has been
// opened (as opposed to merely walked to) on behalf of the client.
type openFDLisa struct {
	lisafs.OpenFD

	// hostFD is the host file descriptor which can be used to make syscalls.
	hostFD int
}

// Compile-time assertion that openFDLisa satisfies lisafs.OpenFDImpl.
var _ lisafs.OpenFDImpl = (*openFDLisa)(nil)
+
// newOpenFDLisa wraps hostFD in an openFDLisa opened from control FD fd with
// the given open flags. Ownership of hostFD transfers to the returned FD,
// which is registered with fd's connection via OpenFD.Init.
func (fd *controlFDLisa) newOpenFDLisa(hostFD int, flags uint32) *openFDLisa {
	newFD := &openFDLisa{
		hostFD: hostFD,
	}
	newFD.OpenFD.Init(fd.FD(), flags, newFD)
	return newFD
}
+
+// FD implements lisafs.OpenFDImpl.FD.
+func (fd *openFDLisa) FD() *lisafs.OpenFD {
+ if fd == nil {
+ return nil
+ }
+ return &fd.OpenFD
+}
+
+// Close implements lisafs.OpenFDImpl.Close.
+func (fd *openFDLisa) Close(c *lisafs.Connection) {
+ if fd.hostFD >= 0 {
+ _ = unix.Close(fd.hostFD)
+ fd.hostFD = -1
+ }
+}
+
+// Stat implements lisafs.OpenFDImpl.Stat.
+func (fd *openFDLisa) Stat(c *lisafs.Connection, comm lisafs.Communicator) (uint32, error) {
+ var resp linux.Statx
+ if err := fstatTo(fd.hostFD, &resp); err != nil {
+ return 0, err
+ }
+
+ respLen := uint32(resp.SizeBytes())
+ resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+ return respLen, nil
+}
+
// Sync implements lisafs.OpenFDImpl.Sync.
//
// It flushes the file's data and metadata to stable storage via fsync(2).
func (fd *openFDLisa) Sync(c *lisafs.Connection) error {
	return unix.Fsync(fd.hostFD)
}
+
+// Write implements lisafs.OpenFDImpl.Write.
+func (fd *openFDLisa) Write(c *lisafs.Connection, comm lisafs.Communicator, buf []byte, off uint64) (uint32, error) {
+ rw := rwfd.NewReadWriter(fd.hostFD)
+ n, err := rw.WriteAt(buf, int64(off))
+ if err != nil {
+ return 0, err
+ }
+
+ resp := &lisafs.PWriteResp{Count: uint64(n)}
+ respLen := uint32(resp.SizeBytes())
+ resp.MarshalUnsafe(comm.PayloadBuf(respLen))
+ return respLen, nil
+}
+
// Read implements lisafs.OpenFDImpl.Read.
//
// It reads up to count bytes from the host FD at offset off, writing the
// data directly after the manually-marshalled byte-count header in comm's
// payload buffer. EOF is not an error: a short (possibly zero-length) read
// is reported via the NumBytes field instead.
func (fd *openFDLisa) Read(c *lisafs.Connection, comm lisafs.Communicator, off uint64, count uint32) (uint32, error) {
	// To save an allocation and a copy, we directly read into the payload
	// buffer. The rest of the response message is manually marshalled.
	var resp lisafs.PReadResp
	respMetaSize := uint32(resp.NumBytes.SizeBytes())
	maxRespLen := respMetaSize + count

	payloadBuf := comm.PayloadBuf(maxRespLen)
	rw := rwfd.NewReadWriter(fd.hostFD)
	n, err := rw.ReadAt(payloadBuf[respMetaSize:], int64(off))
	if err != nil && err != io.EOF {
		return 0, err
	}

	// Write the response metadata onto the payload buffer. The response contents
	// already have been written immediately after it.
	resp.NumBytes = primitive.Uint32(n)
	resp.NumBytes.MarshalUnsafe(payloadBuf[:respMetaSize])
	return respMetaSize + uint32(n), nil
}
+
// Allocate implements lisafs.OpenFDImpl.Allocate.
//
// It delegates directly to fallocate(2) with the given mode, offset and
// length.
func (fd *openFDLisa) Allocate(c *lisafs.Connection, mode, off, length uint64) error {
	return unix.Fallocate(fd.hostFD, uint32(mode), int64(off), int64(length))
}
+
// Flush implements lisafs.OpenFDImpl.Flush.
//
// It is deliberately a no-op: there is nothing to flush for a host-backed FD
// on close of a file description.
func (fd *openFDLisa) Flush(c *lisafs.Connection) error {
	return nil
}
+
// Getdent64 implements lisafs.OpenFDImpl.Getdent64.
//
// It reads up to count bytes worth of directory entries from the host FD
// (optionally rewinding to offset 0 first when seek0 is set), converts each
// host dirent into a lisafs.Dirent64 — including the per-entry device
// numbers, which require an extra fstatat per dirent — and marshals them by
// hand into comm's payload buffer, prefixed by the number of dirents.
func (fd *openFDLisa) Getdent64(c *lisafs.Connection, comm lisafs.Communicator, count uint32, seek0 bool) (uint32, error) {
	if seek0 {
		if _, err := unix.Seek(fd.hostFD, 0, 0); err != nil {
			return 0, err
		}
	}

	// We will manually marshal the response lisafs.Getdents64Resp.

	// numDirents is the number of dirents marshalled into the payload.
	var numDirents primitive.Uint32
	// The payload starts with numDirents, dirents go right after that.
	// payloadBufPos represents the position at which to write the next dirent.
	payloadBufPos := uint32(numDirents.SizeBytes())
	// Request enough payloadBuf for 10 dirents, we will extend when needed.
	payloadBuf := comm.PayloadBuf(payloadBufPos + 10*unixDirentMaxSize)

	var direntsBuf [8192]byte
	var bytesRead int
	for bytesRead < int(count) {
		bufEnd := len(direntsBuf)
		if remaining := int(count) - bytesRead; remaining < bufEnd {
			bufEnd = remaining
		}
		n, err := unix.Getdents(fd.hostFD, direntsBuf[:bufEnd])
		if err != nil {
			if err == unix.EINVAL && bufEnd < 268 {
				// getdents64(2) returns EINVAL when the result buffer is
				// too small. If bufEnd is smaller than the max size of
				// unix.Dirent, then just break here to return all dirents
				// collected till now.
				break
			}
			return 0, err
		}
		if n <= 0 {
			// End of directory.
			break
		}
		bytesRead += n

		var statErr error
		parseDirents(direntsBuf[:n], func(ino uint64, off int64, ftype uint8, name string) bool {
			dirent := lisafs.Dirent64{
				Ino:  primitive.Uint64(ino),
				Off:  primitive.Uint64(off),
				Type: primitive.Uint8(ftype),
				Name: lisafs.SizedString(name),
			}

			// The client also wants the device ID, which annoyingly incurs an
			// additional syscall per dirent. Live with it.
			stat, err := statAt(fd.hostFD, name)
			if err != nil {
				statErr = err
				return false
			}
			dirent.DevMinor = primitive.Uint32(unix.Minor(stat.Dev))
			dirent.DevMajor = primitive.Uint32(unix.Major(stat.Dev))

			// Paste the dirent into the payload buffer without having the dirent
			// escape. Request a larger buffer if needed.
			if int(payloadBufPos)+dirent.SizeBytes() > len(payloadBuf) {
				// Ask for 10 large dirents worth of more space.
				payloadBuf = comm.PayloadBuf(payloadBufPos + 10*unixDirentMaxSize)
			}
			dirent.MarshalBytes(payloadBuf[payloadBufPos:])
			payloadBufPos += uint32(dirent.SizeBytes())
			numDirents++
			return true
		})
		if statErr != nil {
			return 0, statErr
		}
	}

	// The number of dirents goes at the beginning of the payload.
	numDirents.MarshalUnsafe(payloadBuf)
	return payloadBufPos, nil
}
+
// tryStepLocked tries to walk via open() with different modes as documented.
// It then initializes and returns the control FD.
//
// On success it returns the new control FD (registered on c) along with the
// fstat(2) result for the opened file so callers can avoid a second stat.
// ENOENT short-circuits the retry loop since the file cannot appear between
// attempts while the rename lock is held.
//
// Precondition: server's rename mutex must at least be read locked.
func tryStepLocked(c *lisafs.Connection, name string, parent *controlFDLisa, open func(flags int) (int, error)) (*controlFDLisa, unix.Stat_t, error) {
	// Attempt to open file in the following in order:
	//   1. RDONLY | NONBLOCK: for all files, directories, ro mounts, FIFOs.
	//      Use non-blocking to prevent getting stuck inside open(2) for
	//      FIFOs. This option has no effect on regular files.
	//   2. PATH: for symlinks, sockets.
	options := []struct {
		flag     int
		readable bool
	}{
		{
			flag:     unix.O_RDONLY | unix.O_NONBLOCK,
			readable: true,
		},
		{
			flag:     unix.O_PATH,
			readable: false,
		},
	}

	for i, option := range options {
		hostFD, err := open(option.flag | openFlags)
		if err == nil {
			var stat unix.Stat_t
			if err = unix.Fstat(hostFD, &stat); err == nil {
				return newControlFDLisaLocked(c, hostFD, parent, name, linux.FileMode(stat.Mode)), stat, nil
			}
			// Fstat failed; err now carries that error for the checks below.
			unix.Close(hostFD)
		}

		e := extractErrno(err)
		if e == unix.ENOENT {
			// File doesn't exist, no point in retrying.
			return nil, unix.Stat_t{}, e
		}
		if i < len(options)-1 {
			// Retry with the next, more permissive open mode.
			continue
		}
		return nil, unix.Stat_t{}, e
	}
	panic("unreachable")
}
+
+func fstatTo(hostFD int, stat *linux.Statx) error {
+ var unixStat unix.Stat_t
+ if err := unix.Fstat(hostFD, &unixStat); err != nil {
+ return err
+ }
+
+ unixToLinuxStat(&unixStat, stat)
+ return nil
+}
+
+func unixToLinuxStat(from *unix.Stat_t, to *linux.Statx) {
+ to.Mask = unix.STATX_TYPE | unix.STATX_MODE | unix.STATX_INO | unix.STATX_NLINK | unix.STATX_UID | unix.STATX_GID | unix.STATX_SIZE | unix.STATX_BLOCKS | unix.STATX_ATIME | unix.STATX_MTIME | unix.STATX_CTIME
+ to.Mode = uint16(from.Mode)
+ to.DevMinor = unix.Minor(from.Dev)
+ to.DevMajor = unix.Major(from.Dev)
+ to.Ino = from.Ino
+ to.Nlink = uint32(from.Nlink)
+ to.UID = from.Uid
+ to.GID = from.Gid
+ to.RdevMinor = unix.Minor(from.Rdev)
+ to.RdevMajor = unix.Major(from.Rdev)
+ to.Size = uint64(from.Size)
+ to.Blksize = uint32(from.Blksize)
+ to.Blocks = uint64(from.Blocks)
+ to.Atime.Sec = from.Atim.Sec
+ to.Atime.Nsec = uint32(from.Atim.Nsec)
+ to.Mtime.Sec = from.Mtim.Sec
+ to.Mtime.Nsec = uint32(from.Mtim.Nsec)
+ to.Ctime.Sec = from.Ctim.Sec
+ to.Ctime.Nsec = uint32(from.Ctim.Nsec)
+}
diff --git a/runsc/fsgofer/lisafs_test.go b/runsc/fsgofer/lisafs_test.go
new file mode 100644
index 000000000..4653f9955
--- /dev/null
+++ b/runsc/fsgofer/lisafs_test.go
@@ -0,0 +1,56 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lisafs_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/lisafs"
+ "gvisor.dev/gvisor/pkg/lisafs/testsuite"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/runsc/fsgofer"
+)
+
+// Note that these are not supposed to be extensive or robust tests. These unit
+// tests provide a sanity check that all RPCs at least work in obvious ways.
+
// init raises the log level for easier test debugging and opens the
// /proc/self/fd directory that fsgofer requires for its *at syscalls.
// A failure here is unrecoverable for the test binary, so it panics.
func init() {
	log.SetLevel(log.Debug)
	if err := fsgofer.OpenProcSelfFD(); err != nil {
		panic(err)
	}
}
+
// tester implements testsuite.Tester for the fsgofer-backed lisafs server.
// It is stateless; all configuration happens in NewServer.
type tester struct{}
+
// NewServer implements testsuite.Tester.NewServer. It builds a lisafs server
// backed by fsgofer with host UDS and verity xattr support enabled.
func (tester) NewServer(t *testing.T) *lisafs.Server {
	return &fsgofer.NewLisafsServer(fsgofer.Config{HostUDS: true, EnableVerityXattr: true}).Server
}
+
// LinkSupported implements testsuite.Tester.LinkSupported. The local
// filesystem gofer supports hard links, so the suite's link tests run.
func (tester) LinkSupported() bool {
	return true
}
+
// SetUserGroupIDSupported implements testsuite.Tester.SetUserGroupIDSupported.
// Ownership changes are supported, so the suite's chown tests run.
func (tester) SetUserGroupIDSupported() bool {
	return true
}
+
// TestFSGofer runs the shared lisafs local-filesystem test suite against the
// fsgofer implementation.
func TestFSGofer(t *testing.T) {
	testsuite.RunAllLocalFSTests(t, tester{})
}
diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD
index d625230dd..bca14c7b8 100644
--- a/runsc/sandbox/BUILD
+++ b/runsc/sandbox/BUILD
@@ -1,10 +1,11 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "sandbox",
srcs = [
+ "memory.go",
"network.go",
"network_unsafe.go",
"sandbox.go",
@@ -39,3 +40,10 @@ go_library(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "sandbox_test",
+ size = "small",
+ srcs = ["memory_test.go"],
+ library = ":sandbox",
+)
diff --git a/runsc/sandbox/memory.go b/runsc/sandbox/memory.go
new file mode 100644
index 000000000..77b88eb89
--- /dev/null
+++ b/runsc/sandbox/memory.go
@@ -0,0 +1,70 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sandbox
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "os"
+ "strconv"
+ "strings"
+)
+
// totalSystemMemory extracts "MemTotal" from "/proc/meminfo".
//
// It returns the host's total memory in bytes, or an error if the file
// cannot be opened or parsed. Parsing is delegated to
// parseTotalSystemMemory so it can be unit tested without the real file.
func totalSystemMemory() (uint64, error) {
	f, err := os.Open("/proc/meminfo")
	if err != nil {
		return 0, err
	}
	defer f.Close()
	return parseTotalSystemMemory(f)
}
+
+func parseTotalSystemMemory(r io.Reader) (uint64, error) {
+ for scanner := bufio.NewScanner(r); scanner.Scan(); {
+ line := scanner.Text()
+ totalStr := strings.TrimPrefix(line, "MemTotal:")
+ if len(totalStr) < len(line) {
+ fields := strings.Fields(totalStr)
+ if len(fields) == 0 || len(fields) > 2 {
+ return 0, fmt.Errorf(`malformed "MemTotal": %q`, line)
+ }
+ totalStr = fields[0]
+ unit := ""
+ if len(fields) == 2 {
+ unit = fields[1]
+ }
+ mem, err := strconv.ParseUint(totalStr, 10, 64)
+ if err != nil {
+ return 0, err
+ }
+ switch unit {
+ case "":
+ // do nothing.
+ case "kB":
+ memKb := mem
+ mem = memKb * 1024
+ if mem < memKb {
+ return 0, fmt.Errorf(`"MemTotal" too large: %d`, memKb)
+ }
+ default:
+ return 0, fmt.Errorf("unknown unit %q: %q", unit, line)
+ }
+ return mem, nil
+ }
+ }
+ return 0, fmt.Errorf(`malformed "/proc/meminfo": "MemTotal" not found`)
+}
diff --git a/runsc/sandbox/memory_test.go b/runsc/sandbox/memory_test.go
new file mode 100644
index 000000000..81dc67881
--- /dev/null
+++ b/runsc/sandbox/memory_test.go
@@ -0,0 +1,91 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sandbox
+
+import (
+ "bytes"
+ "fmt"
+ "math"
+ "strings"
+ "testing"
+)
+
+func TestTotalSystemMemory(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ content string
+ want uint64
+ err string
+ }{
+ {
+ name: "simple",
+ content: "MemTotal: 123\n",
+ want: 123,
+ },
+ {
+ name: "kb",
+ content: "MemTotal: 123 kB\n",
+ want: 123 * 1024,
+ },
+ {
+ name: "multi-line",
+ content: "Something: 123\nMemTotal: 456\nAnotherThing: 789\n",
+ want: 456,
+ },
+ {
+ name: "not-found",
+ content: "Something: 123 kB\nAnotherThing: 789 kB\n",
+ err: "not found",
+ },
+ {
+ name: "no-number",
+ content: "MemTotal: \n",
+ err: "malformed",
+ },
+ {
+ name: "only-unit",
+ content: "MemTotal: kB\n",
+ err: "invalid syntax",
+ },
+ {
+ name: "negative",
+ content: "MemTotal: -1\n",
+ err: "invalid syntax",
+ },
+ {
+ name: "overflow",
+ content: fmt.Sprintf("MemTotal: %d kB\n", uint64(math.MaxUint64)),
+ err: "too large",
+ },
+ {
+ name: "unkown-unit",
+ content: "MemTotal: 123 mB\n",
+ err: "unknown unit",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ mem, err := parseTotalSystemMemory(bytes.NewReader([]byte(tc.content)))
+ if len(tc.err) > 0 {
+ if err == nil || !strings.Contains(err.Error(), tc.err) {
+ t.Errorf("parseTotalSystemMemory(%q) invalid error: %v, want: %v", tc.content, err, tc.err)
+ }
+ } else {
+ if tc.want != mem {
+ t.Errorf("parseTotalSystemMemory(%q) got: %v, want: %v", tc.content, mem, tc.want)
+ }
+ }
+ })
+ }
+}
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index 9fbce6bd6..714919139 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -490,6 +490,61 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
cmd.Args = append(cmd.Args, "--start-sync-fd="+strconv.Itoa(nextFD))
nextFD++
+ if conf.ProfileBlock != "" {
+ blockFile, err := os.OpenFile(conf.ProfileBlock, os.O_CREATE|os.O_WRONLY, 0644)
+ if err != nil {
+ return fmt.Errorf("opening block profiling file %q: %v", conf.ProfileBlock, err)
+ }
+ defer blockFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, blockFile)
+ cmd.Args = append(cmd.Args, "--profile-block-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+
+ if conf.ProfileCPU != "" {
+ cpuFile, err := os.OpenFile(conf.ProfileCPU, os.O_CREATE|os.O_WRONLY, 0644)
+ if err != nil {
+ return fmt.Errorf("opening cpu profiling file %q: %v", conf.ProfileCPU, err)
+ }
+ defer cpuFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, cpuFile)
+ cmd.Args = append(cmd.Args, "--profile-cpu-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+
+ if conf.ProfileHeap != "" {
+ heapFile, err := os.OpenFile(conf.ProfileHeap, os.O_CREATE|os.O_WRONLY, 0644)
+ if err != nil {
+ return fmt.Errorf("opening heap profiling file %q: %v", conf.ProfileHeap, err)
+ }
+ defer heapFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, heapFile)
+ cmd.Args = append(cmd.Args, "--profile-heap-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+
+ if conf.ProfileMutex != "" {
+ mutexFile, err := os.OpenFile(conf.ProfileMutex, os.O_CREATE|os.O_WRONLY, 0644)
+ if err != nil {
+ return fmt.Errorf("opening mutex profiling file %q: %v", conf.ProfileMutex, err)
+ }
+ defer mutexFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, mutexFile)
+ cmd.Args = append(cmd.Args, "--profile-mutex-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+
+ if conf.TraceFile != "" {
+ traceFile, err := os.OpenFile(conf.TraceFile, os.O_CREATE|os.O_WRONLY, 0644)
+ if err != nil {
+ return fmt.Errorf("opening trace file %q: %v", conf.TraceFile, err)
+ }
+ defer traceFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, traceFile)
+ cmd.Args = append(cmd.Args, "--trace-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
+
// If there is a gofer, sends all socket ends to the sandbox.
for _, f := range args.IOFiles {
defer f.Close()
@@ -703,6 +758,11 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
// shown as `exe`.
cmd.Args[0] = "runsc-sandbox"
+ mem, err := totalSystemMemory()
+ if err != nil {
+ return err
+ }
+
if s.Cgroup != nil {
cpuNum, err := s.Cgroup.NumCPU()
if err != nil {
@@ -730,16 +790,15 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
}
cmd.Args = append(cmd.Args, "--cpu-num", strconv.Itoa(cpuNum))
- mem, err := s.Cgroup.MemoryLimit()
+ memLimit, err := s.Cgroup.MemoryLimit()
if err != nil {
return fmt.Errorf("getting memory limit from cgroups: %v", err)
}
- // When memory limit is unset, a "large" number is returned. In that case,
- // just stick with the default.
- if mem < 0x7ffffffffffff000 {
- cmd.Args = append(cmd.Args, "--total-memory", strconv.FormatUint(mem, 10))
+ if memLimit < mem {
+ mem = memLimit
}
}
+ cmd.Args = append(cmd.Args, "--total-memory", strconv.FormatUint(mem, 10))
if args.UserLog != "" {
f, err := os.OpenFile(args.UserLog, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0664)
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index 5365b5b1b..9d3b97277 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -332,9 +332,9 @@ func capsFromNames(names []string, skipSet map[linux.Capability]struct{}) (auth.
return auth.CapabilitySetOfMany(caps), nil
}
-// Is9PMount returns true if the given mount can be mounted as an external
+// IsGoferMount returns true if the given mount can be mounted as an external
// gofer.
-func Is9PMount(m specs.Mount, vfs2Enabled bool) bool {
+func IsGoferMount(m specs.Mount, vfs2Enabled bool) bool {
MaybeConvertToBindMount(&m)
return m.Type == "bind" && m.Source != "" && IsSupportedDevMount(m, vfs2Enabled)
}
diff --git a/test/benchmarks/tcp/tcp_proxy.go b/test/benchmarks/tcp/tcp_proxy.go
index be74e4d4a..8273b12a7 100644
--- a/test/benchmarks/tcp/tcp_proxy.go
+++ b/test/benchmarks/tcp/tcp_proxy.go
@@ -208,8 +208,12 @@ func newNetstackImpl(mode string) (impl, error) {
if err := s.CreateNIC(nicID, fifo.New(ep, runtime.GOMAXPROCS(0), 1000)); err != nil {
return nil, fmt.Errorf("error creating NIC %q: %v", *iface, err)
}
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, parsedAddr); err != nil {
- return nil, fmt.Errorf("error adding IP address to %q: %v", *iface, err)
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: parsedAddr.WithPrefix(),
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
+ return nil, fmt.Errorf("error adding IP address %+v to %q: %s", protocolAddr, *iface, err)
}
subnet, err := tcpip.NewSubnet(parsedDest, parsedMask)
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
index d41139944..c5aa35623 100644
--- a/test/e2e/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -29,6 +29,7 @@ import (
"net"
"net/http"
"os"
+ "os/exec"
"path/filepath"
"regexp"
"strconv"
@@ -41,8 +42,12 @@ import (
"gvisor.dev/gvisor/pkg/test/testutil"
)
-// defaultWait is the default wait time used for tests.
-const defaultWait = time.Minute
+const (
+ // defaultWait is the default wait time used for tests.
+ defaultWait = time.Minute
+
+ memInfoCmd = "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'"
+)
func TestMain(m *testing.M) {
dockerutil.EnsureSupportedDockerVersion()
@@ -255,16 +260,47 @@ func TestConnectToSelf(t *testing.T) {
}
}
+func TestMemory(t *testing.T) {
+ // Find total amount of memory in the host.
+ host, err := exec.Command("sh", "-c", memInfoCmd).CombinedOutput()
+ if err != nil {
+ t.Fatal(err)
+ }
+ want, err := strconv.ParseUint(strings.TrimSpace(string(host)), 10, 64)
+ if err != nil {
+ t.Fatalf("failed to parse %q: %v", host, err)
+ }
+
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ out, err := d.Run(ctx, dockerutil.RunOpts{Image: "basic/alpine"}, "sh", "-c", memInfoCmd)
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Get memory from inside the container and ensure it matches the host.
+ got, err := strconv.ParseUint(strings.TrimSpace(out), 10, 64)
+ if err != nil {
+ t.Fatalf("failed to parse %q: %v", out, err)
+ }
+ if got != want {
+ t.Errorf("MemTotal got: %d, want: %d", got, want)
+ }
+}
+
func TestMemLimit(t *testing.T) {
ctx := context.Background()
d := dockerutil.MakeContainer(ctx, t)
defer d.CleanUp(ctx)
allocMemoryKb := 50 * 1024
- out, err := d.Run(ctx, dockerutil.RunOpts{
+ opts := dockerutil.RunOpts{
Image: "basic/alpine",
Memory: allocMemoryKb * 1024, // In bytes.
- }, "sh", "-c", "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'")
+ }
+ out, err := d.Run(ctx, opts, "sh", "-c", memInfoCmd)
if err != nil {
t.Fatalf("docker run failed: %v", err)
}
diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go
index 02678a76a..f27e52f93 100644
--- a/test/packetimpact/runner/dut.go
+++ b/test/packetimpact/runner/dut.go
@@ -331,6 +331,22 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
t.Logf("sniffer logs:\n%s", snifferOut)
}
})
+ }
+
+	// Arm the cleanup hook before we do anything else. Otherwise failures below
+ // can cause the test to hang in the cleanup hook above.
+ t.Cleanup(func() {
+ // Wait 1 second before killing tcpdump to give it time to flush
+ // any packets. On linux tests killing it immediately can
+ // sometimes result in partial pcaps.
+ time.Sleep(1 * time.Second)
+ if logs, err := testbenchContainer.Exec(ctx, dockerutil.ExecOpts{}, "killall", baseSnifferArgs[0]); err != nil {
+ t.Errorf("failed to kill all sniffers: %s, logs: %s", err, logs)
+ }
+ })
+
+ for _, info := range dutInfos {
+ n := info.Net
// When the Linux kernel receives a SYN-ACK for a SYN it didn't send, it
// will respond with an RST. In most packetimpact tests, the SYN is sent
// by the raw socket, the kernel knows nothing about the connection, this
@@ -344,16 +360,6 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
}
}
- t.Cleanup(func() {
- // Wait 1 second before killing tcpdump to give it time to flush
- // any packets. On linux tests killing it immediately can
- // sometimes result in partial pcaps.
- time.Sleep(1 * time.Second)
- if logs, err := testbenchContainer.Exec(ctx, dockerutil.ExecOpts{}, "killall", baseSnifferArgs[0]); err != nil {
- t.Errorf("failed to kill all sniffers: %s, logs: %s", err, logs)
- }
- })
-
// FIXME(b/156449515): Some piece of the system has a race. The old
// bash script version had a sleep, so we have one too. The race should
// be fixed and this sleep removed.
diff --git a/test/packetimpact/tests/tcp_listen_backlog_test.go b/test/packetimpact/tests/tcp_listen_backlog_test.go
index fea7d5b6f..e124002f6 100644
--- a/test/packetimpact/tests/tcp_listen_backlog_test.go
+++ b/test/packetimpact/tests/tcp_listen_backlog_test.go
@@ -15,7 +15,9 @@
package tcp_listen_backlog_test
import (
+ "bytes"
"flag"
+ "sync"
"testing"
"time"
@@ -35,60 +37,272 @@ func init() {
func TestTCPListenBacklog(t *testing.T) {
dut := testbench.NewDUT(t)
- // Listening endpoint accepts one more connection than the listen backlog.
- listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 0 /*backlog*/)
+ // This is the number of pending connections before SYN cookies are used.
+ const backlog = 10
- var establishedConn testbench.TCPIPv4
- var incompleteConn testbench.TCPIPv4
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, backlog)
+ defer dut.Close(t, listenFd)
- // Test if the DUT listener replies to more SYNs than listen backlog+1
- for i, conn := range []*testbench.TCPIPv4{&establishedConn, &incompleteConn} {
- *conn = dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- // Expect dut connection to have transitioned to SYN-RCVD state.
+ // Fill the SYN queue with connections in SYN-RCVD. We will use these to test
+ // that ACKs received while the accept queue is full are ignored.
+ var synQueueConns [backlog]testbench.TCPIPv4
+ defer func() {
+ for i := range synQueueConns {
+ synQueueConns[i].Close(t)
+ }
+ }()
+ {
+ var wg sync.WaitGroup
+ for i := range synQueueConns {
+ conn := &synQueueConns[i]
+ *conn = dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{})
+
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
+ if got, err := conn.Expect(t, testbench.TCP{}, time.Second); err != nil {
+ t.Errorf("%d: expected TCP frame: %s", i, err)
+ } else if got, want := *got.Flags, header.TCPFlagSyn|header.TCPFlagAck; got != want {
+ t.Errorf("%d: got %s, want %s", i, got, want)
+ }
+ }(i)
+ }
+ wg.Wait()
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ const payloadLen = 1
+ payload := testbench.Payload{Bytes: testbench.GenerateRandomPayload(t, payloadLen)}
+
+ // Fill the accept queue with connections established using SYN cookies.
+ var synCookieConns [backlog + 1]testbench.TCPIPv4
+ defer func() {
+ for i := range synCookieConns {
+ synCookieConns[i].Close(t)
+ }
+ }()
+ {
+ var wg sync.WaitGroup
+ for i := range synCookieConns {
+ conn := &synCookieConns[i]
+ *conn = dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{})
+
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
+ if got, err := conn.Expect(t, testbench.TCP{}, time.Second); err != nil {
+ t.Errorf("%d: expected TCP frame: %s", i, err)
+ } else if got, want := *got.Flags, header.TCPFlagSyn|header.TCPFlagAck; got != want {
+ t.Errorf("%d: got %s, want %s", i, got, want)
+ }
+ // Send a payload so we can observe the dut ACK.
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, &payload)
+ if got, err := conn.Expect(t, testbench.TCP{}, time.Second); err != nil {
+ t.Errorf("%d: expected TCP frame: %s", i, err)
+ } else if got, want := *got.Flags, header.TCPFlagAck; got != want {
+ t.Errorf("%d: got %s, want %s", i, got, want)
+ }
+ }(i)
+ }
+ wg.Wait()
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ // Send ACKs to complete the handshakes. These are expected to be dropped
+ // because the accept queue is full.
+ {
+ var wg sync.WaitGroup
+ for i := range synQueueConns {
+ conn := &synQueueConns[i]
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+ // Wait for the SYN-ACK to be retransmitted to confirm the ACK was
+ // dropped.
+ seqNum := uint32(*conn.RemoteSeqNum(t) - 1)
+ if got, err := conn.Expect(t, testbench.TCP{SeqNum: &seqNum}, time.Second); err != nil {
+ t.Errorf("%d: expected TCP frame: %s", i, err)
+ } else if got, want := *got.Flags, header.TCPFlagSyn|header.TCPFlagAck; got != want {
+ t.Errorf("%d: got %s, want %s", i, got, want)
+ }
+ }(i)
+ }
+
+ wg.Wait()
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ // While the accept queue is still full, send an unexpected ACK from a new
+ // socket. The listener should reply with an RST.
+ func() {
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{})
+ defer conn.Close(t)
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+ if got, err := conn.Expect(t, testbench.TCP{}, time.Second); err != nil {
+ t.Errorf("expected TCP frame: %s", err)
+ } else if got, want := *got.Flags, header.TCPFlagRst; got != want {
+ t.Errorf("got %s, want %s", got, want)
+ }
+ }()
+
+ func() {
+ // Now initiate a new connection when the accept queue is full.
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{})
+ defer conn.Close(t)
+ // Expect dut connection to drop the SYN.
conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
- t.Fatalf("expected SYN-ACK for %d connection, %s", i, err)
+ if got, err := conn.Expect(t, testbench.TCP{}, time.Second); err == nil {
+ t.Fatalf("expected no TCP frame, got %s", got)
+ }
+ }()
+
+ // Drain the accept queue.
+ {
+ var wg sync.WaitGroup
+ for i := range synCookieConns {
+ conn := &synCookieConns[i]
+
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+
+ fd, _ := dut.Accept(t, listenFd)
+ b := dut.Recv(t, fd, payloadLen+1, 0)
+ dut.Close(t, fd)
+ if !bytes.Equal(b, payload.Bytes) {
+ t.Errorf("connection %d: got dut.Recv = %x, want = %x", i, b, payload.Bytes)
+ }
+
+ if got, err := conn.Expect(t, testbench.TCP{}, time.Second); err != nil {
+ t.Errorf("%d: expected TCP frame: %s", i, err)
+ } else if got, want := *got.Flags, header.TCPFlagFin|header.TCPFlagAck; got != want {
+ t.Errorf("%d: got %s, want %s", i, got, want)
+ }
+
+ // Prevent retransmission.
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+ }(i)
+ }
+ wg.Wait()
+ if t.Failed() {
+ t.FailNow()
}
}
- defer establishedConn.Close(t)
- defer incompleteConn.Close(t)
-
- // Send the ACK to complete handshake.
- establishedConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
-
- // Poll for the established connection ready for accept.
- dut.PollOne(t, listenFd, unix.POLLIN, time.Second)
-
- // Send the ACK to complete handshake, expect this to be dropped by the
- // listener as the accept queue would be full because of the previous
- // handshake.
- incompleteConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
- // Let the test wait for sometime so that the ACK is indeed dropped by
- // the listener. Without such a wait, the DUT accept can race with
- // ACK handling (dropping) causing the test to be flaky.
- time.Sleep(100 * time.Millisecond)
-
- // Drain the accept queue to enable poll for subsequent connections on the
- // listener.
- fd, _ := dut.Accept(t, listenFd)
- dut.Close(t, fd)
-
- // The ACK for the incomplete connection should be ignored by the
- // listening endpoint and the poll on listener should now time out.
- if pfds := dut.Poll(t, []unix.PollFd{{Fd: listenFd, Events: unix.POLLIN}}, time.Second); len(pfds) != 0 {
- t.Fatalf("got dut.Poll(...) = %#v", pfds)
+
+ // Complete the partial connections to move them from the SYN queue to the
+ // accept queue. We will use these to test that connections in the accept
+ // queue are closed on listener shutdown.
+ {
+ var wg sync.WaitGroup
+ for i := range synQueueConns {
+ conn := &synQueueConns[i]
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+
+ tcp := testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}
+
+ // Exercise connections with and without pending data.
+ if i%2 == 0 {
+ // Send ACK with no payload; wait for absence of SYN-ACK retransmit.
+ conn.Send(t, tcp)
+ if got, err := conn.Expect(t, testbench.TCP{}, time.Second); err == nil {
+ t.Errorf("%d: expected no TCP frame, got %s", i, got)
+ }
+ } else {
+ // Send ACK with payload; wait for ACK.
+ conn.Send(t, tcp, &payload)
+ if got, err := conn.Expect(t, testbench.TCP{}, time.Second); err != nil {
+ t.Errorf("%d: expected TCP frame: %s", i, err)
+ } else if got, want := *got.Flags, header.TCPFlagAck; got != want {
+ t.Errorf("%d: got %s, want %s", i, got, want)
+ }
+ }
+ }(i)
+ }
+
+ wg.Wait()
+ if t.Failed() {
+ t.FailNow()
+ }
}
- // Re-send the ACK to complete handshake and re-fill the accept-queue.
- incompleteConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
- dut.PollOne(t, listenFd, unix.POLLIN, time.Second)
-
- // Now initiate a new connection when the accept queue is full.
- connectingConn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer connectingConn.Close(t)
- // Expect dut connection to drop the SYN and let the client stay in SYN_SENT state.
- connectingConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
- if got, err := connectingConn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err == nil {
- t.Fatalf("expected no SYN-ACK, but got %s", got)
+ // The accept queue now has N-1 connections in it. The next incoming SYN will
+ // enter the SYN queue, and the one following will use SYN cookies. We test
+ // both.
+ var connectingConns [2]testbench.TCPIPv4
+ defer func() {
+ for i := range connectingConns {
+ connectingConns[i].Close(t)
+ }
+ }()
+ {
+ var wg sync.WaitGroup
+ for i := range connectingConns {
+ conn := &connectingConns[i]
+ *conn = dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{})
+
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
+ if got, err := conn.Expect(t, testbench.TCP{}, time.Second); err != nil {
+ t.Errorf("%d: expected TCP frame: %s", i, err)
+ } else if got, want := *got.Flags, header.TCPFlagSyn|header.TCPFlagAck; got != want {
+ t.Errorf("%d: got %s, want %s", i, got, want)
+ }
+ }(i)
+ }
+ wg.Wait()
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ dut.Shutdown(t, listenFd, unix.SHUT_RD)
+
+ var wg sync.WaitGroup
+
+ // Shutdown causes connections in the accept queue to be closed.
+ for i := range synQueueConns {
+ conn := &synQueueConns[i]
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+
+ if got, err := conn.Expect(t, testbench.TCP{}, time.Second); err != nil {
+ t.Errorf("%d: expected TCP frame: %s", i, err)
+ } else if got, want := *got.Flags, header.TCPFlagRst|header.TCPFlagAck; got != want {
+ t.Errorf("%d: got %s, want %s", i, got, want)
+ }
+ }(i)
+ }
+
+ for i := range connectingConns {
+ conn := &connectingConns[i]
+
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+
+ if got, err := conn.Expect(t, testbench.TCP{}, time.Second); err == nil {
+ t.Errorf("%d: expected no TCP frame, got %s", i, got)
+ }
+ }(i)
}
+
+ wg.Wait()
}
diff --git a/test/runner/main.go b/test/runner/main.go
index 2cab8d2d4..c155431de 100644
--- a/test/runner/main.go
+++ b/test/runner/main.go
@@ -180,11 +180,9 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
if *overlay {
args = append(args, "-overlay")
}
- if *vfs2 {
- args = append(args, "-vfs2")
- if *fuse {
- args = append(args, "-fuse")
- }
+ args = append(args, fmt.Sprintf("-vfs2=%t", *vfs2))
+ if *vfs2 && *fuse {
+ args = append(args, "-fuse")
}
if *debug {
args = append(args, "-debug", "-log-packets=true")
diff --git a/test/runtimes/runner/lib/BUILD b/test/runtimes/runner/lib/BUILD
index d308f41b0..3491c535b 100644
--- a/test/runtimes/runner/lib/BUILD
+++ b/test/runtimes/runner/lib/BUILD
@@ -5,7 +5,11 @@ package(licenses = ["notice"])
go_library(
name = "lib",
testonly = 1,
- srcs = ["lib.go"],
+ srcs = [
+ "go_test_dependency_go118.go",
+ "go_test_dependency_not_go118.go",
+ "lib.go",
+ ],
visibility = ["//test/runtimes/runner:__pkg__"],
deps = [
"//pkg/log",
diff --git a/test/runtimes/runner/lib/go_test_dependency_go118.go b/test/runtimes/runner/lib/go_test_dependency_go118.go
new file mode 100644
index 000000000..d430e81c7
--- /dev/null
+++ b/test/runtimes/runner/lib/go_test_dependency_go118.go
@@ -0,0 +1,27 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18 && go1.1
+// +build go1.18,go1.1
+
+package lib
+
+import (
+ "testing"
+)
+
+// mainStart wraps testing.MainStart for Go release >= 1.18.
+func mainStart(tests []testing.InternalTest) *testing.M {
+ return testing.MainStart(testDeps{}, tests, nil, nil, nil)
+}
diff --git a/test/runtimes/runner/lib/go_test_dependency_not_go118.go b/test/runtimes/runner/lib/go_test_dependency_not_go118.go
new file mode 100644
index 000000000..8b0b34c72
--- /dev/null
+++ b/test/runtimes/runner/lib/go_test_dependency_not_go118.go
@@ -0,0 +1,25 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build !go1.18 && go1.1
+// +build !go1.18,go1.1
+
+package lib
+
+import "testing"
+
+// mainStart wraps testing.MainStart for Go release < 1.18.
+func mainStart(tests []testing.InternalTest) *testing.M {
+ return testing.MainStart(testDeps{}, tests, nil, nil)
+}
diff --git a/test/runtimes/runner/lib/lib.go b/test/runtimes/runner/lib/lib.go
index d6b652897..6b95b4cfa 100644
--- a/test/runtimes/runner/lib/lib.go
+++ b/test/runtimes/runner/lib/lib.go
@@ -21,6 +21,7 @@ import (
"fmt"
"io"
"os"
+ "reflect"
"sort"
"strings"
"testing"
@@ -63,8 +64,7 @@ func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Durat
fmt.Fprintf(os.Stderr, "%s\n", err.Error())
return 1
}
-
- m := testing.MainStart(testDeps{}, tests, nil, nil)
+ m := mainStart(tests)
return m.Run()
}
@@ -197,3 +197,21 @@ func (f testDeps) ImportPath() string { return "" }
func (f testDeps) StartTestLog(io.Writer) {}
func (f testDeps) StopTestLog() error { return nil }
func (f testDeps) SetPanicOnExit0(bool) {}
+func (f testDeps) CoordinateFuzzing(time.Duration, int64, time.Duration, int64, int, []corpusEntry, []reflect.Type, string, string) error {
+ return nil
+}
+func (f testDeps) RunFuzzWorker(func(corpusEntry) error) error { return nil }
+func (f testDeps) ReadCorpus(string, []reflect.Type) ([]corpusEntry, error) { return nil, nil }
+func (f testDeps) CheckCorpus([]interface{}, []reflect.Type) error { return nil }
+func (f testDeps) ResetCoverage() {}
+func (f testDeps) SnapshotCoverage() {}
+
+// Copied from testing/fuzz.go.
+type corpusEntry = struct {
+ Parent string
+ Path string
+ Data []byte
+ Values []interface{}
+ Generation int
+ IsSeed bool
+}
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 43854e80b..f748d685a 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -363,6 +363,10 @@ syscall_test(
)
syscall_test(
+ test = "//test/syscalls/linux:packet_socket_dgram_test",
+)
+
+syscall_test(
test = "//test/syscalls/linux:packet_socket_raw_test",
)
@@ -434,6 +438,11 @@ syscall_test(
)
syscall_test(
+ tags = ["container"],
+ test = "//test/syscalls/linux:proc_isolated_test",
+)
+
+syscall_test(
test = "//test/syscalls/linux:proc_net_test",
)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 8d501a3a3..9d975c614 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -8,6 +8,7 @@ package(
exports_files(
[
"packet_socket.cc",
+ "packet_socket_dgram.cc",
"packet_socket_raw.cc",
"raw_socket.cc",
"raw_socket_hdrincl.cc",
@@ -589,6 +590,7 @@ cc_binary(
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
+ "//test/util:thread_util",
],
)
@@ -1441,10 +1443,9 @@ cc_binary(
)
cc_binary(
- name = "packet_socket_raw_test",
+ name = "packet_socket_dgram_test",
testonly = 1,
- srcs = ["packet_socket_raw.cc"],
- defines = select_system(),
+ srcs = ["packet_socket_dgram.cc"],
linkstatic = 1,
deps = [
":ip_socket_test_util",
@@ -1462,9 +1463,9 @@ cc_binary(
)
cc_binary(
- name = "packet_socket_test",
+ name = "packet_socket_raw_test",
testonly = 1,
- srcs = ["packet_socket.cc"],
+ srcs = ["packet_socket_raw.cc"],
linkstatic = 1,
deps = [
":ip_socket_test_util",
@@ -1482,6 +1483,21 @@ cc_binary(
)
cc_binary(
+ name = "packet_socket_test",
+ testonly = 1,
+ srcs = ["packet_socket.cc"],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:socket_util",
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
name = "pty_test",
testonly = 1,
srcs = ["pty.cc"],
@@ -1778,6 +1794,20 @@ cc_binary(
)
cc_binary(
+ name = "proc_isolated_test",
+ testonly = 1,
+ srcs = ["proc_isolated.cc"],
+ linkstatic = 1,
+ deps = [
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
name = "proc_net_test",
testonly = 1,
srcs = ["proc_net.cc"],
@@ -1949,6 +1979,7 @@ cc_binary(
defines = select_system(),
linkstatic = 1,
deps = [
+ ":ip_socket_test_util",
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
@@ -2631,16 +2662,11 @@ cc_library(
cc_library(
name = "socket_ip_udp_unbound_external_networking",
testonly = 1,
- srcs = [
- "socket_ip_udp_unbound_external_networking.cc",
- ],
hdrs = [
"socket_ip_udp_unbound_external_networking.h",
],
deps = [
":ip_socket_test_util",
- "//test/util:socket_util",
- "//test/util:test_util",
],
alwayslink = 1,
)
@@ -2656,6 +2682,9 @@ cc_library(
],
deps = [
":socket_ip_udp_unbound_external_networking",
+ "//test/util:socket_util",
+ "//test/util:test_util",
+ "@com_google_absl//absl/cleanup",
gtest,
],
alwayslink = 1,
@@ -4174,17 +4203,17 @@ cc_binary(
name = "mq_test",
testonly = 1,
srcs = ["mq.cc"],
- linkstatic = 1,
linkopts = [
"-lrt",
],
+ linkstatic = 1,
deps = [
"//test/util:capability_util",
"//test/util:fs_util",
"//test/util:mount_util",
"//test/util:posix_error",
- "//test/util:test_main",
"//test/util:temp_path",
+ "//test/util:test_main",
"//test/util:test_util",
],
)
diff --git a/test/syscalls/linux/cgroup.cc b/test/syscalls/linux/cgroup.cc
index ca23dfeee..50705a86e 100644
--- a/test/syscalls/linux/cgroup.cc
+++ b/test/syscalls/linux/cgroup.cc
@@ -37,6 +37,8 @@ namespace {
using ::testing::_;
using ::testing::Contains;
+using ::testing::Each;
+using ::testing::Eq;
using ::testing::Ge;
using ::testing::Gt;
using ::testing::Key;
@@ -383,6 +385,32 @@ PosixError WriteAndVerifyControlValue(const Cgroup& c, std::string_view path,
return NoError();
}
+PosixErrorOr<std::vector<bool>> ParseBitmap(std::string s) {
+ std::vector<bool> bitmap;
+ bitmap.reserve(64);
+ for (const std::string_view& t : absl::StrSplit(s, ',')) {
+ std::vector<std::string> parts = absl::StrSplit(t, absl::MaxSplits('-', 2));
+ if (parts.size() == 2) {
+ ASSIGN_OR_RETURN_ERRNO(int64_t start, Atoi<int64_t>(parts[0]));
+ ASSIGN_OR_RETURN_ERRNO(int64_t end, Atoi<int64_t>(parts[1]));
+ // Note: start and end are indices into bitmap.
+ if (end >= bitmap.size()) {
+ bitmap.resize(end + 1, false);
+ }
+ for (int i = start; i <= end; ++i) {
+ bitmap[i] = true;
+ }
+ } else { // parts.size() == 1, 0 not possible.
+ ASSIGN_OR_RETURN_ERRNO(int64_t i, Atoi<int64_t>(parts[0]));
+ if (i >= bitmap.size()) {
+ bitmap.resize(i + 1, false);
+ }
+ bitmap[i] = true;
+ }
+ }
+ return bitmap;
+}
+
TEST(JobCgroup, ReadWriteRead) {
SKIP_IF(!CgroupsAvailable());
@@ -396,6 +424,59 @@ TEST(JobCgroup, ReadWriteRead) {
EXPECT_NO_ERRNO(WriteAndVerifyControlValue(c, "job.id", LLONG_MAX));
}
+TEST(CpusetCgroup, Defaults) {
+ SKIP_IF(!CgroupsAvailable());
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpuset"));
+
+ std::string cpus =
+ ASSERT_NO_ERRNO_AND_VALUE(c.ReadControlFile("cpuset.cpus"));
+ std::vector<bool> cpus_bitmap = ASSERT_NO_ERRNO_AND_VALUE(ParseBitmap(cpus));
+ EXPECT_GT(cpus_bitmap.size(), 0);
+ EXPECT_THAT(cpus_bitmap, Each(Eq(true)));
+
+ std::string mems =
+ ASSERT_NO_ERRNO_AND_VALUE(c.ReadControlFile("cpuset.mems"));
+ std::vector<bool> mems_bitmap = ASSERT_NO_ERRNO_AND_VALUE(ParseBitmap(mems));
+ EXPECT_GT(mems_bitmap.size(), 0);
+ EXPECT_THAT(mems_bitmap, Each(Eq(true)));
+}
+
+TEST(CpusetCgroup, SetMask) {
+ SKIP_IF(!CgroupsAvailable());
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpuset"));
+
+ std::string cpus =
+ ASSERT_NO_ERRNO_AND_VALUE(c.ReadControlFile("cpuset.cpus"));
+ std::vector<bool> cpus_bitmap = ASSERT_NO_ERRNO_AND_VALUE(ParseBitmap(cpus));
+
+ SKIP_IF(cpus_bitmap.size() <= 1); // "Not enough CPUs"
+
+ int max_cpu = cpus_bitmap.size() - 1;
+ ASSERT_NO_ERRNO(
+ c.WriteControlFile("cpuset.cpus", absl::StrCat("1-", max_cpu)));
+ cpus_bitmap[0] = false;
+ cpus = ASSERT_NO_ERRNO_AND_VALUE(c.ReadControlFile("cpuset.cpus"));
+ std::vector<bool> cpus_bitmap_after =
+ ASSERT_NO_ERRNO_AND_VALUE(ParseBitmap(cpus));
+ EXPECT_EQ(cpus_bitmap_after, cpus_bitmap);
+}
+
+TEST(CpusetCgroup, SetEmptyMask) {
+ SKIP_IF(!CgroupsAvailable());
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpuset"));
+ ASSERT_NO_ERRNO(c.WriteControlFile("cpuset.cpus", ""));
+ std::string_view cpus = absl::StripAsciiWhitespace(
+ ASSERT_NO_ERRNO_AND_VALUE(c.ReadControlFile("cpuset.cpus")));
+ EXPECT_EQ(cpus, "");
+ ASSERT_NO_ERRNO(c.WriteControlFile("cpuset.mems", ""));
+ std::string_view mems = absl::StripAsciiWhitespace(
+ ASSERT_NO_ERRNO_AND_VALUE(c.ReadControlFile("cpuset.mems")));
+ EXPECT_EQ(mems, "");
+}
+
TEST(ProcCgroups, Empty) {
SKIP_IF(!CgroupsAvailable());
diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc
index 3ef8b0327..c2dc8174c 100644
--- a/test/syscalls/linux/epoll.cc
+++ b/test/syscalls/linux/epoll.cc
@@ -30,6 +30,7 @@
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
namespace gvisor {
namespace testing {
@@ -496,6 +497,41 @@ TEST(EpollTest, PipeReaderHupAfterWriterClosed) {
EXPECT_EQ(result[0].data.u64, kMagicConstant);
}
+TEST(EpollTest, DoubleLayerEpoll) {
+ int pipefds[2];
+ ASSERT_THAT(pipe2(pipefds, O_NONBLOCK), SyscallSucceeds());
+ FileDescriptor rfd(pipefds[0]);
+ FileDescriptor wfd(pipefds[1]);
+
+ auto epfd1 = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ ASSERT_NO_ERRNO(
+ RegisterEpollFD(epfd1.get(), rfd.get(), EPOLLIN | EPOLLHUP, rfd.get()));
+
+ auto epfd2 = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ ASSERT_NO_ERRNO(RegisterEpollFD(epfd2.get(), epfd1.get(), EPOLLIN | EPOLLHUP,
+ epfd1.get()));
+
+ // Write to wfd and then check if epoll events were generated correctly.
+ // Run this loop a couple of times to check if event in epfd1 is cleaned.
+ constexpr char data[] = "data";
+ for (int i = 0; i < 2; ++i) {
+ ScopedThread thread1([&wfd, &data]() {
+ sleep(1);
+ ASSERT_EQ(WriteFd(wfd.get(), data, sizeof(data)), sizeof(data));
+ });
+
+ struct epoll_event ret_events[2];
+ ASSERT_THAT(RetryEINTR(epoll_wait)(epfd2.get(), ret_events, 2, 5000),
+ SyscallSucceedsWithValue(1));
+ ASSERT_EQ(ret_events[0].data.fd, epfd1.get());
+ ASSERT_THAT(RetryEINTR(epoll_wait)(epfd1.get(), ret_events, 2, 5000),
+ SyscallSucceedsWithValue(1));
+ ASSERT_EQ(ret_events[0].data.fd, rfd.get());
+ char readBuf[sizeof(data)];
+ ASSERT_EQ(ReadFd(rfd.get(), readBuf, sizeof(data)), sizeof(data));
+ }
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc
index a1216d23f..d18e616d0 100644
--- a/test/syscalls/linux/ip_socket_test_util.cc
+++ b/test/syscalls/linux/ip_socket_test_util.cc
@@ -16,6 +16,7 @@
#include <net/if.h>
#include <netinet/in.h>
+#include <netpacket/packet.h>
#include <sys/socket.h>
#include <cstring>
@@ -196,75 +197,53 @@ SocketKind IPv6TCPUnboundSocket(int type) {
UnboundSocketCreator(AF_INET6, type | SOCK_STREAM, IPPROTO_TCP)};
}
-PosixError IfAddrHelper::Load() {
- Release();
-#ifndef ANDROID
- RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_));
-#else
- // Android does not support getifaddrs in r22.
- return PosixError(ENOSYS, "getifaddrs");
-#endif
- return NoError();
-}
-
-void IfAddrHelper::Release() {
- if (ifaddr_) {
-#ifndef ANDROID
- // Android does not support freeifaddrs in r22.
- freeifaddrs(ifaddr_);
-#endif
- ifaddr_ = nullptr;
- }
-}
-
-std::vector<std::string> IfAddrHelper::InterfaceList(int family) const {
- std::vector<std::string> names;
- for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) {
- if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) {
- continue;
- }
- names.emplace(names.end(), ifa->ifa_name);
- }
- return names;
-}
-
-const sockaddr* IfAddrHelper::GetAddr(int family, std::string name) const {
- for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) {
- if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) {
- continue;
- }
- if (name == ifa->ifa_name) {
- return ifa->ifa_addr;
- }
- }
- return nullptr;
-}
-
-PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) const {
- return InterfaceIndex(name);
-}
-
std::string GetAddr4Str(const in_addr* a) {
char str[INET_ADDRSTRLEN];
- inet_ntop(AF_INET, a, str, sizeof(str));
- return std::string(str);
+ return inet_ntop(AF_INET, a, str, sizeof(str));
}
std::string GetAddr6Str(const in6_addr* a) {
char str[INET6_ADDRSTRLEN];
- inet_ntop(AF_INET6, a, str, sizeof(str));
- return std::string(str);
+ return inet_ntop(AF_INET6, a, str, sizeof(str));
}
std::string GetAddrStr(const sockaddr* a) {
- if (a->sa_family == AF_INET) {
- auto src = &(reinterpret_cast<const sockaddr_in*>(a)->sin_addr);
- return GetAddr4Str(src);
- } else if (a->sa_family == AF_INET6) {
- auto src = &(reinterpret_cast<const sockaddr_in6*>(a)->sin6_addr);
- return GetAddr6Str(src);
+ switch (a->sa_family) {
+ case AF_INET: {
+ return GetAddr4Str(&(reinterpret_cast<const sockaddr_in*>(a)->sin_addr));
+ }
+ case AF_INET6: {
+ return GetAddr6Str(
+ &(reinterpret_cast<const sockaddr_in6*>(a)->sin6_addr));
+ }
+ case AF_PACKET: {
+ const sockaddr_ll& ll = *reinterpret_cast<const sockaddr_ll*>(a);
+ std::ostringstream ss;
+ ss << std::hex;
+ ss << std::showbase;
+ ss << '{';
+ ss << " protocol=" << ntohs(ll.sll_protocol);
+ ss << " ifindex=" << ll.sll_ifindex;
+ ss << " hatype=" << ll.sll_hatype;
+ ss << " pkttype=" << static_cast<unsigned short>(ll.sll_pkttype);
+ if (ll.sll_halen != 0) {
+ ss << " addr=";
+ for (unsigned char i = 0; i < ll.sll_halen; ++i) {
+ if (i != 0) {
+ ss << ':';
+ }
+ ss << static_cast<unsigned short>(ll.sll_addr[i]);
+ }
+ }
+ ss << " }";
+ return ss.str();
+ }
+ default: {
+ std::ostringstream ss;
+ ss << "invalid(sa_family=" << a->sa_family << ")";
+ return ss.str();
+ }
}
- return std::string("<invalid>");
}
} // namespace testing
diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h
index 556838356..957006e25 100644
--- a/test/syscalls/linux/ip_socket_test_util.h
+++ b/test/syscalls/linux/ip_socket_test_util.h
@@ -115,25 +115,6 @@ SocketKind IPv4TCPUnboundSocket(int type);
// created with AF_INET6, SOCK_STREAM, IPPROTO_TCP and the given type.
SocketKind IPv6TCPUnboundSocket(int type);
-// IfAddrHelper is a helper class that determines the local interfaces present
-// and provides functions to obtain their names, index numbers, and IP address.
-class IfAddrHelper {
- public:
- IfAddrHelper() : ifaddr_(nullptr) {}
- ~IfAddrHelper() { Release(); }
-
- PosixError Load();
- void Release();
-
- std::vector<std::string> InterfaceList(int family) const;
-
- const sockaddr* GetAddr(int family, std::string name) const;
- PosixErrorOr<int> GetIndex(std::string name) const;
-
- private:
- struct ifaddrs* ifaddr_;
-};
-
// GetAddr4Str returns the given IPv4 network address structure as a string.
std::string GetAddr4Str(const in_addr* a);
diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc
index fda176261..bb2a0cb57 100644
--- a/test/syscalls/linux/mmap.cc
+++ b/test/syscalls/linux/mmap.cc
@@ -29,6 +29,7 @@
#include <sys/wait.h>
#include <unistd.h>
+#include <limits>
#include <vector>
#include "gmock/gmock.h"
@@ -913,13 +914,41 @@ TEST_F(MMapFileTest, MapOffsetBeyondEnd) {
::testing::KilledBySignal(SIGBUS), "");
}
-// Verify mmap fails when sum of length and offset overflows.
-TEST_F(MMapFileTest, MapLengthPlusOffsetOverflows) {
+TEST_F(MMapFileTest, MapSecondToLastPositivePage) {
SKIP_IF(!FSSupportsMap());
- const size_t length = static_cast<size_t>(-kPageSize);
- const off_t offset = kPageSize;
- ASSERT_THAT(Map(0, length, PROT_READ, MAP_PRIVATE, fd_.get(), offset),
- SyscallFailsWithErrno(ENOMEM));
+ EXPECT_THAT(
+ Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(),
+ (std::numeric_limits<off_t>::max() - kPageSize) & ~(kPageSize - 1)),
+ SyscallSucceeds());
+}
+
+TEST_F(MMapFileTest, MapLastPositivePage) {
+ SKIP_IF(!FSSupportsMap());
+ // For regular files, this should fail due to integer overflow of the end
+ // offset.
+ EXPECT_THAT(Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(),
+ std::numeric_limits<off_t>::max() & ~(kPageSize - 1)),
+ SyscallFailsWithErrno(EOVERFLOW));
+}
+
+TEST_F(MMapFileTest, MapFirstNegativePage) {
+ SKIP_IF(!FSSupportsMap());
+ EXPECT_THAT(Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(),
+ std::numeric_limits<off_t>::min()),
+ SyscallFailsWithErrno(EOVERFLOW));
+}
+
+TEST_F(MMapFileTest, MapSecondToLastNegativePage) {
+ SKIP_IF(!FSSupportsMap());
+ EXPECT_THAT(
+ Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), -(2 * kPageSize)),
+ SyscallFailsWithErrno(EOVERFLOW));
+}
+
+TEST_F(MMapFileTest, MapLastNegativePage) {
+ SKIP_IF(!FSSupportsMap());
+ EXPECT_THAT(Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), -kPageSize),
+ SyscallFailsWithErrno(EOVERFLOW));
}
// MAP_PRIVATE PROT_WRITE is allowed on read-only FDs.
diff --git a/test/syscalls/linux/mq.cc b/test/syscalls/linux/mq.cc
index b3cd6d1b9..36181f149 100644
--- a/test/syscalls/linux/mq.cc
+++ b/test/syscalls/linux/mq.cc
@@ -44,7 +44,9 @@ class PosixQueue {
PosixQueue& operator=(const PosixQueue&) = delete;
// Move constructor.
- PosixQueue(PosixQueue&& q) : fd_(q.fd_), name_(q.name_) {
+ PosixQueue(PosixQueue&& q) {
+ fd_ = q.fd_;
+ name_ = q.name_;
// Call PosixQueue::release, to prevent the object being released from
// unlinking the underlying queue.
q.release();
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index 9515a48f2..c706848fc 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,50 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <arpa/inet.h>
-#include <ifaddrs.h>
-#include <net/ethernet.h>
#include <net/if.h>
-#include <net/if_arp.h>
-#include <netinet/in.h>
-#include <netinet/ip.h>
-#include <netinet/udp.h>
+#include <netinet/if_ether.h>
#include <netpacket/packet.h>
#include <poll.h>
-#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
-#include <unistd.h>
+
+#include <limits>
#include "gtest/gtest.h"
-#include "absl/base/internal/endian.h"
#include "test/syscalls/linux/ip_socket_test_util.h"
-#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/capability_util.h"
-#include "test/util/cleanup.h"
#include "test/util/file_descriptor.h"
#include "test/util/socket_util.h"
-#include "test/util/test_util.h"
-
-// Some of these tests involve sending packets via AF_PACKET sockets and the
-// loopback interface. Because AF_PACKET circumvents so much of the networking
-// stack, Linux sees these packets as "martian", i.e. they claim to be to/from
-// localhost but don't have the usual associated data. Thus Linux drops them by
-// default. You can see where this happens by following the code at:
-//
-// - net/ipv4/ip_input.c:ip_rcv_finish, which calls
-// - net/ipv4/route.c:ip_route_input_noref, which calls
-// - net/ipv4/route.c:ip_route_input_slow, which finds and drops martian
-// packets.
-//
-// To tell Linux not to drop these packets, you need to tell it to accept our
-// funny packets (which are completely valid and correct, but lack associated
-// in-kernel data because we use AF_PACKET):
-//
-// echo 1 >> /proc/sys/net/ipv4/conf/lo/accept_local
-// echo 1 >> /proc/sys/net/ipv4/conf/lo/route_localnet
-//
-// These tests require CAP_NET_RAW to run.
namespace gvisor {
namespace testing {
@@ -63,486 +33,257 @@ namespace testing {
namespace {
using ::testing::AnyOf;
+using ::testing::Combine;
using ::testing::Eq;
+using ::testing::Values;
-constexpr char kMessage[] = "soweoneul malhaebwa";
-constexpr in_port_t kPort = 0x409c; // htons(40000)
-
-//
-// "Cooked" tests. Cooked AF_PACKET sockets do not contain link layer
-// headers, and provide link layer destination/source information via a
-// returned struct sockaddr_ll.
-//
-
-// Send kMessage via sock to loopback
-void SendUDPMessage(int sock) {
- struct sockaddr_in dest = {};
- dest.sin_port = kPort;
- dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
- dest.sin_family = AF_INET;
- EXPECT_THAT(sendto(sock, kMessage, sizeof(kMessage), 0,
- reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
- SyscallSucceedsWithValue(sizeof(kMessage)));
-}
-
-// Send an IP packet and make sure ETH_P_<something else> doesn't pick it up.
-TEST(BasicCookedPacketTest, WrongType) {
- if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) {
- ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP),
- SyscallFailsWithErrno(EPERM));
- GTEST_SKIP();
- }
-
- FileDescriptor sock = ASSERT_NO_ERRNO_AND_VALUE(
- Socket(AF_PACKET, SOCK_DGRAM, htons(ETH_P_PUP)));
-
- // Let's use a simple IP payload: a UDP datagram.
- FileDescriptor udp_sock =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
- SendUDPMessage(udp_sock.get());
-
- // Wait and make sure the socket never becomes readable.
- struct pollfd pfd = {};
- pfd.fd = sock.get();
- pfd.events = POLLIN;
- EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
-}
-
-// Tests for "cooked" (SOCK_DGRAM) packet(7) sockets.
-class CookedPacketTest : public ::testing::TestWithParam<int> {
+class PacketSocketCreationTest
+ : public ::testing::TestWithParam<std::tuple<int, int>> {
protected:
- // Creates a socket to be used in tests.
- void SetUp() override;
-
- // Closes the socket created by SetUp().
- void TearDown() override;
-
- // Gets the device index of the loopback device.
- int GetLoopbackIndex();
-
- // The socket used for both reading and writing.
- int socket_;
-};
-
-void CookedPacketTest::SetUp() {
- if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) {
- ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
- SyscallFailsWithErrno(EPERM));
- GTEST_SKIP();
- }
-
- if (!IsRunningOnGvisor()) {
- FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE(
- Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY));
- FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE(
- Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY));
- char enabled;
- ASSERT_THAT(read(acceptLocal.get(), &enabled, 1),
- SyscallSucceedsWithValue(1));
- ASSERT_EQ(enabled, '1');
- ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1),
- SyscallSucceedsWithValue(1));
- ASSERT_EQ(enabled, '1');
+ void SetUp() override {
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) {
+ const auto [type, protocol] = GetParam();
+ ASSERT_THAT(socket(AF_PACKET, type, htons(protocol)),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP() << "Missing packet socket capability";
+ }
}
+};
- ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
- SyscallSucceeds());
+TEST_P(PacketSocketCreationTest, Create) {
+ const auto [type, protocol] = GetParam();
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, type, htons(protocol)));
+ EXPECT_GE(fd.get(), 0);
}
-void CookedPacketTest::TearDown() {
- // TearDown will be run even if we skip the test.
- if (ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) {
- EXPECT_THAT(close(socket_), SyscallSucceeds());
- }
-}
+INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketCreationTest,
+ Combine(Values(SOCK_DGRAM, SOCK_RAW),
+ Values(0, 1, 255, ETH_P_IP, ETH_P_IPV6,
+ std::numeric_limits<uint16_t>::max())));
-int CookedPacketTest::GetLoopbackIndex() {
- int v = EXPECT_NO_ERRNO_AND_VALUE(gvisor::testing::GetLoopbackIndex());
- EXPECT_NE(v, 0);
- return v;
-}
+class PacketSocketTest : public ::testing::TestWithParam<int> {
+ protected:
+ void SetUp() override {
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) {
+ ASSERT_THAT(socket(AF_PACKET, GetParam(), 0),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP() << "Missing packet socket capability";
+ }
-// Receive and verify the message via packet socket on interface.
-void ReceiveMessage(int sock, int ifindex) {
- // Wait for the socket to become readable.
- struct pollfd pfd = {};
- pfd.fd = sock;
- pfd.events = POLLIN;
- EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1));
-
- // Read and verify the data.
- constexpr size_t packet_size =
- sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage);
- char buf[64];
- struct sockaddr_ll src = {};
- socklen_t src_len = sizeof(src);
- ASSERT_THAT(recvfrom(sock, buf, sizeof(buf), 0,
- reinterpret_cast<struct sockaddr*>(&src), &src_len),
- SyscallSucceedsWithValue(packet_size));
-
- // sockaddr_ll ends with an 8 byte physical address field, but ethernet
- // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
- // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns
- // sizeof(sockaddr_ll).
- ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
-
- // Verify the source address.
- EXPECT_EQ(src.sll_family, AF_PACKET);
- EXPECT_EQ(src.sll_ifindex, ifindex);
- EXPECT_EQ(src.sll_halen, ETH_ALEN);
- EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP);
- // This came from the loopback device, so the address is all 0s.
- for (int i = 0; i < src.sll_halen; i++) {
- EXPECT_EQ(src.sll_addr[i], 0);
+ socket_ = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, GetParam(), 0));
}
- // Verify the IP header. We memcpy to deal with pointer aligment.
- struct iphdr ip = {};
- memcpy(&ip, buf, sizeof(ip));
- EXPECT_EQ(ip.ihl, 5);
- EXPECT_EQ(ip.version, 4);
- EXPECT_EQ(ip.tot_len, htons(packet_size));
- EXPECT_EQ(ip.protocol, IPPROTO_UDP);
- EXPECT_EQ(ip.daddr, htonl(INADDR_LOOPBACK));
- EXPECT_EQ(ip.saddr, htonl(INADDR_LOOPBACK));
-
- // Verify the UDP header. We memcpy to deal with pointer aligment.
- struct udphdr udp = {};
- memcpy(&udp, buf + sizeof(iphdr), sizeof(udp));
- EXPECT_EQ(udp.dest, kPort);
- EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage)));
-
- // Verify the payload.
- char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr));
- EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0);
-}
-
-// Receive via a packet socket.
-TEST_P(CookedPacketTest, Receive) {
- // Let's use a simple IP payload: a UDP datagram.
- FileDescriptor udp_sock =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
- SendUDPMessage(udp_sock.get());
-
- // Receive and verify the data.
- int loopback_index = GetLoopbackIndex();
- ReceiveMessage(socket_, loopback_index);
-}
-
-// Send via a packet socket.
-TEST_P(CookedPacketTest, Send) {
- // Let's send a UDP packet and receive it using a regular UDP socket.
- FileDescriptor udp_sock =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
- struct sockaddr_in bind_addr = {};
- bind_addr.sin_family = AF_INET;
- bind_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
- bind_addr.sin_port = kPort;
- ASSERT_THAT(
- bind(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&bind_addr),
- sizeof(bind_addr)),
- SyscallSucceeds());
-
- // Set up the destination physical address.
- struct sockaddr_ll dest = {};
- dest.sll_family = AF_PACKET;
- dest.sll_halen = ETH_ALEN;
- dest.sll_ifindex = GetLoopbackIndex();
- dest.sll_protocol = htons(ETH_P_IP);
- // We're sending to the loopback device, so the address is all 0s.
- memset(dest.sll_addr, 0x00, ETH_ALEN);
-
- // Set up the IP header.
- struct iphdr iphdr = {0};
- iphdr.ihl = 5;
- iphdr.version = 4;
- iphdr.tos = 0;
- iphdr.tot_len =
- htons(sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage));
- // Get a pseudo-random ID. If we clash with an in-use ID the test will fail,
- // but we have no way of getting an ID we know to be good.
- srand(*reinterpret_cast<unsigned int*>(&iphdr));
- iphdr.id = rand();
- // Linux sets this bit ("do not fragment") for small packets.
- iphdr.frag_off = 1 << 6;
- iphdr.ttl = 64;
- iphdr.protocol = IPPROTO_UDP;
- iphdr.daddr = htonl(INADDR_LOOPBACK);
- iphdr.saddr = htonl(INADDR_LOOPBACK);
- iphdr.check = IPChecksum(iphdr);
-
- // Set up the UDP header.
- struct udphdr udphdr = {};
- udphdr.source = kPort;
- udphdr.dest = kPort;
- udphdr.len = htons(sizeof(udphdr) + sizeof(kMessage));
- udphdr.check = UDPChecksum(iphdr, udphdr, kMessage, sizeof(kMessage));
-
- // Copy both headers and the payload into our packet buffer.
- char send_buf[sizeof(iphdr) + sizeof(udphdr) + sizeof(kMessage)];
- memcpy(send_buf, &iphdr, sizeof(iphdr));
- memcpy(send_buf + sizeof(iphdr), &udphdr, sizeof(udphdr));
- memcpy(send_buf + sizeof(iphdr) + sizeof(udphdr), kMessage, sizeof(kMessage));
-
- // Send it.
- ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0,
- reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
- SyscallSucceedsWithValue(sizeof(send_buf)));
-
- // Wait for the packet to become available on both sockets.
- struct pollfd pfd = {};
- pfd.fd = udp_sock.get();
- pfd.events = POLLIN;
- ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1));
- pfd.fd = socket_;
- pfd.events = POLLIN;
- ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1));
-
- // Receive on the packet socket.
- char recv_buf[sizeof(send_buf)];
- ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
- ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0);
-
- // Receive on the UDP socket.
- struct sockaddr_in src;
- socklen_t src_len = sizeof(src);
- ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT,
- reinterpret_cast<struct sockaddr*>(&src), &src_len),
- SyscallSucceedsWithValue(sizeof(kMessage)));
- // Check src and payload.
- EXPECT_EQ(strncmp(recv_buf, kMessage, sizeof(kMessage)), 0);
- EXPECT_EQ(src.sin_family, AF_INET);
- EXPECT_EQ(src.sin_port, kPort);
- EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK));
-}
-
-// Bind and receive via packet socket.
-TEST_P(CookedPacketTest, BindReceive) {
- struct sockaddr_ll bind_addr = {};
- bind_addr.sll_family = AF_PACKET;
- bind_addr.sll_protocol = htons(GetParam());
- bind_addr.sll_ifindex = GetLoopbackIndex();
-
- ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
- sizeof(bind_addr)),
- SyscallSucceeds());
-
- // Let's use a simple IP payload: a UDP datagram.
- FileDescriptor udp_sock =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
- SendUDPMessage(udp_sock.get());
-
- // Receive and verify the data.
- ReceiveMessage(socket_, bind_addr.sll_ifindex);
-}
-
-// Double Bind socket.
-TEST_P(CookedPacketTest, DoubleBindSucceeds) {
- struct sockaddr_ll bind_addr = {};
- bind_addr.sll_family = AF_PACKET;
- bind_addr.sll_protocol = htons(GetParam());
- bind_addr.sll_ifindex = GetLoopbackIndex();
-
- ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
- sizeof(bind_addr)),
- SyscallSucceeds());
-
- // Binding socket again should fail.
- ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
- sizeof(bind_addr)),
- // Linux 4.09 returns EINVAL here, but some time before 4.19 it
- // switched to EADDRINUSE.
- SyscallSucceeds());
-}
-
-// Bind and verify we do not receive data on interface which is not bound
-TEST_P(CookedPacketTest, BindDrop) {
- // Let's use a simple IP payload: a UDP datagram.
- FileDescriptor udp_sock =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
-
- struct ifaddrs* if_addr_list = nullptr;
- auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); });
-
- ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds());
+ FileDescriptor socket_;
+};
- // Get interface other than loopback.
- struct ifreq ifr = {};
- for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) {
- if (strcmp(i->ifa_name, "lo") != 0) {
- strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name));
- break;
+TEST_P(PacketSocketTest, GetSockName) {
+ {
+ // First check the local address of an unbound packet socket.
+ sockaddr_ll addr;
+ socklen_t addrlen = sizeof(addr);
+ ASSERT_THAT(getsockname(socket_.get(), reinterpret_cast<sockaddr*>(&addr),
+ &addrlen),
+ SyscallSucceeds());
+ // sockaddr_ll ends with an 8 byte physical address field, but only the
+ // bytes that are used in the sockaddr_ll.sll_addr field are included in the
+ // address length. Seems Linux used to return the size of sockaddr_ll, but
+ // https://github.com/torvalds/linux/commit/0fb375fb9b93b7d822debc6a734052337ccfdb1f
+    // changed things to only return the used portion of the address:
+    // `sizeof(sockaddr_ll) - sizeof(sll_addr) + sll_halen` bytes.
+ ASSERT_THAT(addrlen, AnyOf(Eq(sizeof(addr)),
+ Eq(sizeof(addr) - sizeof(addr.sll_addr))));
+ EXPECT_EQ(addr.sll_family, AF_PACKET);
+ EXPECT_EQ(addr.sll_ifindex, 0);
+ if (IsRunningOnGvisor() && !IsRunningWithHostinet()) {
+ // TODO(https://gvisor.dev/issue/6530): Do not assume all interfaces have
+ // an ethernet address.
+ EXPECT_EQ(addr.sll_halen, ETH_ALEN);
+ } else {
+ EXPECT_EQ(addr.sll_halen, 0);
}
+ EXPECT_EQ(ntohs(addr.sll_protocol), 0);
+ EXPECT_EQ(addr.sll_hatype, 0);
}
-
- // Skip if no interface is available other than loopback.
- if (strlen(ifr.ifr_name) == 0) {
- GTEST_SKIP();
- }
-
- // Get interface index.
- EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
- EXPECT_NE(ifr.ifr_ifindex, 0);
-
- // Bind to packet socket requires only family, protocol and ifindex.
- struct sockaddr_ll bind_addr = {};
- bind_addr.sll_family = AF_PACKET;
- bind_addr.sll_protocol = htons(GetParam());
- bind_addr.sll_ifindex = ifr.ifr_ifindex;
-
- ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ // Next we bind the socket to loopback before checking the local address.
+ const sockaddr_ll bind_addr = {
+ .sll_family = AF_PACKET,
+ .sll_protocol = htons(ETH_P_IP),
+ .sll_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()),
+ };
+ ASSERT_THAT(bind(socket_.get(), reinterpret_cast<const sockaddr*>(&bind_addr),
sizeof(bind_addr)),
SyscallSucceeds());
-
- // Send to loopback interface.
- struct sockaddr_in dest = {};
- dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
- dest.sin_family = AF_INET;
- dest.sin_port = kPort;
- EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0,
- reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
- SyscallSucceedsWithValue(sizeof(kMessage)));
-
- // Wait and make sure the socket never receives any data.
- struct pollfd pfd = {};
- pfd.fd = socket_;
- pfd.events = POLLIN;
- EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
+ {
+ sockaddr_ll addr;
+ socklen_t addrlen = sizeof(addr);
+ ASSERT_THAT(getsockname(socket_.get(), reinterpret_cast<sockaddr*>(&addr),
+ &addrlen),
+ SyscallSucceeds());
+ ASSERT_THAT(addrlen,
+ AnyOf(Eq(sizeof(addr)),
+ Eq(sizeof(addr) - sizeof(addr.sll_addr) + ETH_ALEN)));
+ EXPECT_EQ(addr.sll_family, AF_PACKET);
+ EXPECT_EQ(addr.sll_ifindex, bind_addr.sll_ifindex);
+ EXPECT_EQ(addr.sll_halen, ETH_ALEN);
+ // Bound to loopback which has the all zeroes address.
+ for (int i = 0; i < addr.sll_halen; ++i) {
+ EXPECT_EQ(addr.sll_addr[i], 0) << "byte mismatch @ idx = " << i;
+ }
+    // The socket was just bound with protocol htons(ETH_P_IP), so the
+    // reported protocol must match. (Comparing ntohs(x) against htons(x)
+    // would be vacuously true on every platform.)
+    EXPECT_EQ(ntohs(addr.sll_protocol), ETH_P_IP);
+ if (IsRunningOnGvisor() && !IsRunningWithHostinet()) {
+ // TODO(https://gvisor.dev/issue/6621): Support populating sll_hatype.
+ EXPECT_EQ(addr.sll_hatype, 0);
+ } else {
+ EXPECT_EQ(addr.sll_hatype, ARPHRD_LOOPBACK);
+ }
+ }
}
-// Verify that we receive outbound packets. This test requires at least one
-// non loopback interface so that we can actually capture an outgoing packet.
-TEST_P(CookedPacketTest, ReceiveOutbound) {
- // Only ETH_P_ALL sockets can receive outbound packets on linux.
- SKIP_IF(GetParam() != ETH_P_ALL);
+TEST_P(PacketSocketTest, RebindProtocol) {
+ const bool kEthHdrIncluded = GetParam() == SOCK_RAW;
+
+ sockaddr_in udp_bind_addr = {
+ .sin_family = AF_INET,
+ .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)},
+ };
- // Let's use a simple IP payload: a UDP datagram.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
-
- struct ifaddrs* if_addr_list = nullptr;
- auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); });
-
- ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds());
-
- // Get interface other than loopback.
- struct ifreq ifr = {};
- for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) {
- if (strcmp(i->ifa_name, "lo") != 0) {
- strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name));
- break;
- }
- }
-
- // Skip if no interface is available other than loopback.
- if (strlen(ifr.ifr_name) == 0) {
- GTEST_SKIP();
+ {
+ // Bind the socket so that we have something to send packets to.
+ //
+ // If we didn't do this, the UDP packets we send will be responded to with
+ // ICMP Destination Port Unreachable errors.
+ ASSERT_THAT(
+ bind(udp_sock.get(), reinterpret_cast<const sockaddr*>(&udp_bind_addr),
+ sizeof(udp_bind_addr)),
+ SyscallSucceeds());
+ socklen_t addrlen = sizeof(udp_bind_addr);
+ ASSERT_THAT(
+ getsockname(udp_sock.get(), reinterpret_cast<sockaddr*>(&udp_bind_addr),
+ &addrlen),
+ SyscallSucceeds());
+ ASSERT_THAT(addrlen, sizeof(udp_bind_addr));
}
- // Get interface index and name.
- EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
- EXPECT_NE(ifr.ifr_ifindex, 0);
- int ifindex = ifr.ifr_ifindex;
-
- constexpr int kMACSize = 6;
- char hwaddr[kMACSize];
- // Get interface address.
- ASSERT_THAT(ioctl(socket_, SIOCGIFHWADDR, &ifr), SyscallSucceeds());
- ASSERT_THAT(ifr.ifr_hwaddr.sa_family,
- AnyOf(Eq(ARPHRD_NONE), Eq(ARPHRD_ETHER)));
- memcpy(hwaddr, ifr.ifr_hwaddr.sa_data, kMACSize);
-
- // Just send it to the google dns server 8.8.8.8. It's UDP we don't care
- // if it actually gets to the DNS Server we just want to see that we receive
- // it on our AF_PACKET socket.
- //
- // NOTE: We just want to pick an IP that is non-local to avoid having to
- // handle ARP as this should cause the UDP packet to be sent to the default
- // gateway configured for the system under test. Otherwise the only packet we
- // will see is the ARP query unless we picked an IP which will actually
- // resolve. The test is a bit brittle but this was the best compromise for
- // now.
- struct sockaddr_in dest = {};
- ASSERT_EQ(inet_pton(AF_INET, "8.8.8.8", &dest.sin_addr.s_addr), 1);
- dest.sin_family = AF_INET;
- dest.sin_port = kPort;
- EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0,
- reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
- SyscallSucceedsWithValue(sizeof(kMessage)));
-
- // Wait and make sure the socket receives the data.
- struct pollfd pfd = {};
- pfd.fd = socket_;
- pfd.events = POLLIN;
- EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(1));
-
- // Now read and check that the packet is the one we just sent.
- // Read and verify the data.
- constexpr size_t packet_size =
- sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage);
- char buf[64];
- struct sockaddr_ll src = {};
- socklen_t src_len = sizeof(src);
- ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0,
- reinterpret_cast<struct sockaddr*>(&src), &src_len),
- SyscallSucceedsWithValue(packet_size));
-
- // sockaddr_ll ends with an 8 byte physical address field, but ethernet
- // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
- // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns
- // sizeof(sockaddr_ll).
- ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
-
- // Verify the source address.
- EXPECT_EQ(src.sll_family, AF_PACKET);
- EXPECT_EQ(src.sll_ifindex, ifindex);
- EXPECT_EQ(src.sll_halen, ETH_ALEN);
- EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP);
- EXPECT_EQ(src.sll_pkttype, PACKET_OUTGOING);
- // Verify the link address of the interface matches that of the non
- // non loopback interface address we stored above.
- for (int i = 0; i < src.sll_halen; i++) {
- EXPECT_EQ(src.sll_addr[i], hwaddr[i]);
- }
+ const int loopback_index = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex());
+
+ auto send_udp_message = [&](const uint64_t v) {
+ ASSERT_THAT(
+ sendto(udp_sock.get(), reinterpret_cast<const char*>(&v), sizeof(v),
+ 0 /* flags */, reinterpret_cast<const sockaddr*>(&udp_bind_addr),
+ sizeof(udp_bind_addr)),
+ SyscallSucceeds());
+ };
+
+ auto bind_to_network_protocol = [&](uint16_t protocol) {
+ const sockaddr_ll packet_bind_addr = {
+ .sll_family = AF_PACKET,
+ .sll_protocol = htons(protocol),
+ .sll_ifindex = loopback_index,
+ };
+
+ ASSERT_THAT(bind(socket_.get(),
+ reinterpret_cast<const sockaddr*>(&packet_bind_addr),
+ sizeof(packet_bind_addr)),
+ SyscallSucceeds());
+ };
+
+ auto test_recv = [&, this](const uint64_t v) {
+ constexpr int kInfiniteTimeout = -1;
+ pollfd pfd = {
+ .fd = socket_.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kInfiniteTimeout),
+ SyscallSucceedsWithValue(1));
- // Verify the IP header.
- struct iphdr ip = {};
- memcpy(&ip, buf, sizeof(ip));
- EXPECT_EQ(ip.ihl, 5);
- EXPECT_EQ(ip.version, 4);
- EXPECT_EQ(ip.tot_len, htons(packet_size));
- EXPECT_EQ(ip.protocol, IPPROTO_UDP);
- EXPECT_EQ(ip.daddr, dest.sin_addr.s_addr);
- EXPECT_NE(ip.saddr, htonl(INADDR_LOOPBACK));
-
- // Verify the UDP header.
- struct udphdr udp = {};
- memcpy(&udp, buf + sizeof(iphdr), sizeof(udp));
- EXPECT_EQ(udp.dest, kPort);
- EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage)));
-
- // Verify the payload.
- char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr));
- EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0);
-}
+ struct {
+ ethhdr eth;
+ iphdr ip;
+ udphdr udp;
+ uint64_t payload;
+ char unused;
+ } ABSL_ATTRIBUTE_PACKED read_pkt;
+ sockaddr_ll src;
+ socklen_t src_len = sizeof(src);
+
+ char* buf = reinterpret_cast<char*>(&read_pkt);
+ size_t buflen = sizeof(read_pkt);
+ size_t expected_read_len = sizeof(read_pkt) - sizeof(read_pkt.unused);
+ if (!kEthHdrIncluded) {
+ buf += sizeof(read_pkt.eth);
+ buflen -= sizeof(read_pkt.eth);
+ expected_read_len -= sizeof(read_pkt.eth);
+ }
-// Bind with invalid address.
-TEST_P(CookedPacketTest, BindFail) {
- // Null address.
- ASSERT_THAT(
- bind(socket_, nullptr, sizeof(struct sockaddr)),
- AnyOf(SyscallFailsWithErrno(EFAULT), SyscallFailsWithErrno(EINVAL)));
-
- // Address of size 1.
- uint8_t addr = 0;
- ASSERT_THAT(
- bind(socket_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
- SyscallFailsWithErrno(EINVAL));
+ ASSERT_THAT(recvfrom(socket_.get(), buf, buflen, 0,
+ reinterpret_cast<sockaddr*>(&src), &src_len),
+ SyscallSucceedsWithValue(expected_read_len));
+ // sockaddr_ll ends with an 8 byte physical address field, but ethernet
+ // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
+ // here, but returns sizeof(sockaddr_ll) since
+ // https://github.com/torvalds/linux/commit/b2cf86e1563e33a14a1c69b3e508d15dc12f804c.
+ ASSERT_THAT(src_len,
+ AnyOf(Eq(sizeof(src)),
+ Eq(sizeof(src) - sizeof(src.sll_addr) + ETH_ALEN)));
+ EXPECT_EQ(src.sll_family, AF_PACKET);
+ EXPECT_EQ(src.sll_ifindex, loopback_index);
+ EXPECT_EQ(src.sll_halen, ETH_ALEN);
+ EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP);
+ // This came from the loopback device, so the address is all 0s.
+ constexpr uint8_t allZeroesMAC[ETH_ALEN] = {};
+ EXPECT_EQ(memcmp(src.sll_addr, allZeroesMAC, sizeof(allZeroesMAC)), 0);
+ if (kEthHdrIncluded) {
+ EXPECT_EQ(memcmp(read_pkt.eth.h_dest, allZeroesMAC, sizeof(allZeroesMAC)),
+ 0);
+ EXPECT_EQ(
+ memcmp(read_pkt.eth.h_source, allZeroesMAC, sizeof(allZeroesMAC)), 0);
+ EXPECT_EQ(ntohs(read_pkt.eth.h_proto), ETH_P_IP);
+ }
+ // IHL hold the size of the header in 4 byte units.
+ EXPECT_EQ(read_pkt.ip.ihl, sizeof(iphdr) / 4);
+ EXPECT_EQ(read_pkt.ip.version, IPVERSION);
+ const uint16_t ip_pkt_size =
+ sizeof(read_pkt) - sizeof(read_pkt.eth) - sizeof(read_pkt.unused);
+ EXPECT_EQ(ntohs(read_pkt.ip.tot_len), ip_pkt_size);
+ EXPECT_EQ(read_pkt.ip.protocol, IPPROTO_UDP);
+ EXPECT_EQ(ntohl(read_pkt.ip.daddr), INADDR_LOOPBACK);
+ EXPECT_EQ(ntohl(read_pkt.ip.saddr), INADDR_LOOPBACK);
+ EXPECT_EQ(read_pkt.udp.source, udp_bind_addr.sin_port);
+ EXPECT_EQ(read_pkt.udp.dest, udp_bind_addr.sin_port);
+ EXPECT_EQ(ntohs(read_pkt.udp.len), ip_pkt_size - sizeof(read_pkt.ip));
+ EXPECT_EQ(read_pkt.payload, v);
+ };
+
+ // The packet socket is not bound to IPv4 so we should not receive the sent
+ // message.
+ uint64_t counter = 0;
+ ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter));
+
+ // Bind to IPv4 and expect to receive the UDP packet we send after binding.
+ ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(ETH_P_IP));
+ ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter));
+ ASSERT_NO_FATAL_FAILURE(test_recv(counter));
+
+ // Bind the packet socket to a random protocol.
+ ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(255));
+ ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter));
+
+  // Bind back to IPv4 and expect to receive the UDP packet we send after
+  // binding back to IPv4.
+ ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(ETH_P_IP));
+ ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter));
+ ASSERT_NO_FATAL_FAILURE(test_recv(counter));
+
+ // A zero valued protocol number should not change the bound network protocol.
+ ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(0));
+ ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter));
+ ASSERT_NO_FATAL_FAILURE(test_recv(counter));
}
-INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest,
- ::testing::Values(ETH_P_IP, ETH_P_ALL));
+INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketTest,
+ Values(SOCK_DGRAM, SOCK_RAW));
} // namespace
diff --git a/test/syscalls/linux/packet_socket_dgram.cc b/test/syscalls/linux/packet_socket_dgram.cc
new file mode 100644
index 000000000..9515a48f2
--- /dev/null
+++ b/test/syscalls/linux/packet_socket_dgram.cc
@@ -0,0 +1,550 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <ifaddrs.h>
+#include <net/ethernet.h>
+#include <net/if.h>
+#include <net/if_arp.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/udp.h>
+#include <netpacket/packet.h>
+#include <poll.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/base/internal/endian.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/socket_util.h"
+#include "test/util/test_util.h"
+
+// Some of these tests involve sending packets via AF_PACKET sockets and the
+// loopback interface. Because AF_PACKET circumvents so much of the networking
+// stack, Linux sees these packets as "martian", i.e. they claim to be to/from
+// localhost but don't have the usual associated data. Thus Linux drops them by
+// default. You can see where this happens by following the code at:
+//
+// - net/ipv4/ip_input.c:ip_rcv_finish, which calls
+// - net/ipv4/route.c:ip_route_input_noref, which calls
+// - net/ipv4/route.c:ip_route_input_slow, which finds and drops martian
+// packets.
+//
+// To tell Linux not to drop these packets, you need to tell it to accept our
+// funny packets (which are completely valid and correct, but lack associated
+// in-kernel data because we use AF_PACKET):
+//
+// echo 1 >> /proc/sys/net/ipv4/conf/lo/accept_local
+// echo 1 >> /proc/sys/net/ipv4/conf/lo/route_localnet
+//
+// These tests require CAP_NET_RAW to run.
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using ::testing::AnyOf;
+using ::testing::Eq;
+
+constexpr char kMessage[] = "soweoneul malhaebwa";
+constexpr in_port_t kPort = 0x409c; // htons(40000)
+
+//
+// "Cooked" tests. Cooked AF_PACKET sockets do not contain link layer
+// headers, and provide link layer destination/source information via a
+// returned struct sockaddr_ll.
+//
+
+// Send kMessage via sock to loopback
+void SendUDPMessage(int sock) {
+ struct sockaddr_in dest = {};
+ dest.sin_port = kPort;
+ dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ dest.sin_family = AF_INET;
+ EXPECT_THAT(sendto(sock, kMessage, sizeof(kMessage), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+}
+
+// Send an IP packet and make sure ETH_P_<something else> doesn't pick it up.
+TEST(BasicCookedPacketTest, WrongType) {
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
+ FileDescriptor sock = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_PACKET, SOCK_DGRAM, htons(ETH_P_PUP)));
+
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Wait and make sure the socket never becomes readable.
+ struct pollfd pfd = {};
+ pfd.fd = sock.get();
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
+}
+
+// Tests for "cooked" (SOCK_DGRAM) packet(7) sockets.
+class CookedPacketTest : public ::testing::TestWithParam<int> {
+ protected:
+ // Creates a socket to be used in tests.
+ void SetUp() override;
+
+ // Closes the socket created by SetUp().
+ void TearDown() override;
+
+ // Gets the device index of the loopback device.
+ int GetLoopbackIndex();
+
+ // The socket used for both reading and writing.
+ int socket_;
+};
+
+void CookedPacketTest::SetUp() {
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
+ if (!IsRunningOnGvisor()) {
+ FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE(
+ Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY));
+ FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE(
+ Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY));
+ char enabled;
+ ASSERT_THAT(read(acceptLocal.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ ASSERT_EQ(enabled, '1');
+ ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ ASSERT_EQ(enabled, '1');
+ }
+
+ ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
+ SyscallSucceeds());
+}
+
+void CookedPacketTest::TearDown() {
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
+}
+
+int CookedPacketTest::GetLoopbackIndex() {
+ int v = EXPECT_NO_ERRNO_AND_VALUE(gvisor::testing::GetLoopbackIndex());
+ EXPECT_NE(v, 0);
+ return v;
+}
+
+// Receive and verify the message via packet socket on interface.
+void ReceiveMessage(int sock, int ifindex) {
+ // Wait for the socket to become readable.
+ struct pollfd pfd = {};
+ pfd.fd = sock;
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1));
+
+ // Read and verify the data.
+ constexpr size_t packet_size =
+ sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage);
+ char buf[64];
+ struct sockaddr_ll src = {};
+ socklen_t src_len = sizeof(src);
+ ASSERT_THAT(recvfrom(sock, buf, sizeof(buf), 0,
+ reinterpret_cast<struct sockaddr*>(&src), &src_len),
+ SyscallSucceedsWithValue(packet_size));
+
+ // sockaddr_ll ends with an 8 byte physical address field, but ethernet
+ // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
+ // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns
+ // sizeof(sockaddr_ll).
+ ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
+
+ // Verify the source address.
+ EXPECT_EQ(src.sll_family, AF_PACKET);
+ EXPECT_EQ(src.sll_ifindex, ifindex);
+ EXPECT_EQ(src.sll_halen, ETH_ALEN);
+ EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP);
+ // This came from the loopback device, so the address is all 0s.
+ for (int i = 0; i < src.sll_halen; i++) {
+ EXPECT_EQ(src.sll_addr[i], 0);
+ }
+
+  // Verify the IP header. We memcpy to deal with pointer alignment.
+ struct iphdr ip = {};
+ memcpy(&ip, buf, sizeof(ip));
+ EXPECT_EQ(ip.ihl, 5);
+ EXPECT_EQ(ip.version, 4);
+ EXPECT_EQ(ip.tot_len, htons(packet_size));
+ EXPECT_EQ(ip.protocol, IPPROTO_UDP);
+ EXPECT_EQ(ip.daddr, htonl(INADDR_LOOPBACK));
+ EXPECT_EQ(ip.saddr, htonl(INADDR_LOOPBACK));
+
+  // Verify the UDP header. We memcpy to deal with pointer alignment.
+ struct udphdr udp = {};
+ memcpy(&udp, buf + sizeof(iphdr), sizeof(udp));
+ EXPECT_EQ(udp.dest, kPort);
+ EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage)));
+
+ // Verify the payload.
+ char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr));
+ EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0);
+}
+
+// Receive via a packet socket.
+TEST_P(CookedPacketTest, Receive) {
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Receive and verify the data.
+ int loopback_index = GetLoopbackIndex();
+ ReceiveMessage(socket_, loopback_index);
+}
+
+// Send via a packet socket.
+TEST_P(CookedPacketTest, Send) {
+ // Let's send a UDP packet and receive it using a regular UDP socket.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ struct sockaddr_in bind_addr = {};
+ bind_addr.sin_family = AF_INET;
+ bind_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ bind_addr.sin_port = kPort;
+ ASSERT_THAT(
+ bind(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Set up the destination physical address.
+ struct sockaddr_ll dest = {};
+ dest.sll_family = AF_PACKET;
+ dest.sll_halen = ETH_ALEN;
+ dest.sll_ifindex = GetLoopbackIndex();
+ dest.sll_protocol = htons(ETH_P_IP);
+ // We're sending to the loopback device, so the address is all 0s.
+ memset(dest.sll_addr, 0x00, ETH_ALEN);
+
+ // Set up the IP header.
+ struct iphdr iphdr = {0};
+ iphdr.ihl = 5;
+ iphdr.version = 4;
+ iphdr.tos = 0;
+ iphdr.tot_len =
+ htons(sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage));
+ // Get a pseudo-random ID. If we clash with an in-use ID the test will fail,
+ // but we have no way of getting an ID we know to be good.
+ srand(*reinterpret_cast<unsigned int*>(&iphdr));
+ iphdr.id = rand();
+ // Linux sets this bit ("do not fragment") for small packets.
+ iphdr.frag_off = 1 << 6;
+ iphdr.ttl = 64;
+ iphdr.protocol = IPPROTO_UDP;
+ iphdr.daddr = htonl(INADDR_LOOPBACK);
+ iphdr.saddr = htonl(INADDR_LOOPBACK);
+ iphdr.check = IPChecksum(iphdr);
+
+ // Set up the UDP header.
+ struct udphdr udphdr = {};
+ udphdr.source = kPort;
+ udphdr.dest = kPort;
+ udphdr.len = htons(sizeof(udphdr) + sizeof(kMessage));
+ udphdr.check = UDPChecksum(iphdr, udphdr, kMessage, sizeof(kMessage));
+
+ // Copy both headers and the payload into our packet buffer.
+ char send_buf[sizeof(iphdr) + sizeof(udphdr) + sizeof(kMessage)];
+ memcpy(send_buf, &iphdr, sizeof(iphdr));
+ memcpy(send_buf + sizeof(iphdr), &udphdr, sizeof(udphdr));
+ memcpy(send_buf + sizeof(iphdr) + sizeof(udphdr), kMessage, sizeof(kMessage));
+
+ // Send it.
+ ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Wait for the packet to become available on both sockets.
+ struct pollfd pfd = {};
+ pfd.fd = udp_sock.get();
+ pfd.events = POLLIN;
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1));
+ pfd.fd = socket_;
+ pfd.events = POLLIN;
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1));
+
+ // Receive on the packet socket.
+ char recv_buf[sizeof(send_buf)];
+ ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0);
+
+ // Receive on the UDP socket.
+ struct sockaddr_in src;
+ socklen_t src_len = sizeof(src);
+ ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT,
+ reinterpret_cast<struct sockaddr*>(&src), &src_len),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+ // Check src and payload.
+ EXPECT_EQ(strncmp(recv_buf, kMessage, sizeof(kMessage)), 0);
+ EXPECT_EQ(src.sin_family, AF_INET);
+ EXPECT_EQ(src.sin_port, kPort);
+ EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK));
+}
+
+// Bind and receive via packet socket.
+TEST_P(CookedPacketTest, BindReceive) {
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = GetLoopbackIndex();
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Receive and verify the data.
+ ReceiveMessage(socket_, bind_addr.sll_ifindex);
+}
+
+// Double Bind socket.
+TEST_P(CookedPacketTest, DoubleBindSucceeds) {
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = GetLoopbackIndex();
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+  // Binding the socket again to the same address should succeed.
+  ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+                   sizeof(bind_addr)),
+              // NOTE(review): Linux 4.9 returned EINVAL here and some versions
+              // before 4.19 returned EADDRINUSE; confirm success is intended.
+              SyscallSucceeds());
+}
+
+// Bind and verify we do not receive data on interface which is not bound
+TEST_P(CookedPacketTest, BindDrop) {
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ struct ifaddrs* if_addr_list = nullptr;
+ auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); });
+
+ ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds());
+
+ // Get interface other than loopback.
+ struct ifreq ifr = {};
+ for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) {
+ if (strcmp(i->ifa_name, "lo") != 0) {
+ strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name));
+ break;
+ }
+ }
+
+ // Skip if no interface is available other than loopback.
+ if (strlen(ifr.ifr_name) == 0) {
+ GTEST_SKIP();
+ }
+
+ // Get interface index.
+ EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+
+ // Bind to packet socket requires only family, protocol and ifindex.
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = ifr.ifr_ifindex;
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Send to loopback interface.
+ struct sockaddr_in dest = {};
+ dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ dest.sin_family = AF_INET;
+ dest.sin_port = kPort;
+ EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+
+ // Wait and make sure the socket never receives any data.
+ struct pollfd pfd = {};
+ pfd.fd = socket_;
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
+}
+
+// Verify that we receive outbound packets. This test requires at least one
+// non loopback interface so that we can actually capture an outgoing packet.
+TEST_P(CookedPacketTest, ReceiveOutbound) {
+ // Only ETH_P_ALL sockets can receive outbound packets on linux.
+ SKIP_IF(GetParam() != ETH_P_ALL);
+
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ struct ifaddrs* if_addr_list = nullptr;
+ auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); });
+
+ ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds());
+
+ // Get interface other than loopback.
+ struct ifreq ifr = {};
+ for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) {
+ if (strcmp(i->ifa_name, "lo") != 0) {
+ strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name));
+ break;
+ }
+ }
+
+ // Skip if no interface is available other than loopback.
+ if (strlen(ifr.ifr_name) == 0) {
+ GTEST_SKIP();
+ }
+
+ // Get interface index and name.
+ EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+ int ifindex = ifr.ifr_ifindex;
+
+ constexpr int kMACSize = 6;
+ char hwaddr[kMACSize];
+ // Get interface address.
+ ASSERT_THAT(ioctl(socket_, SIOCGIFHWADDR, &ifr), SyscallSucceeds());
+ ASSERT_THAT(ifr.ifr_hwaddr.sa_family,
+ AnyOf(Eq(ARPHRD_NONE), Eq(ARPHRD_ETHER)));
+ memcpy(hwaddr, ifr.ifr_hwaddr.sa_data, kMACSize);
+
+ // Just send it to the google dns server 8.8.8.8. It's UDP we don't care
+ // if it actually gets to the DNS Server we just want to see that we receive
+ // it on our AF_PACKET socket.
+ //
+ // NOTE: We just want to pick an IP that is non-local to avoid having to
+ // handle ARP as this should cause the UDP packet to be sent to the default
+ // gateway configured for the system under test. Otherwise the only packet we
+ // will see is the ARP query unless we picked an IP which will actually
+ // resolve. The test is a bit brittle but this was the best compromise for
+ // now.
+ struct sockaddr_in dest = {};
+ ASSERT_EQ(inet_pton(AF_INET, "8.8.8.8", &dest.sin_addr.s_addr), 1);
+ dest.sin_family = AF_INET;
+ dest.sin_port = kPort;
+ EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+
+ // Wait and make sure the socket receives the data.
+ struct pollfd pfd = {};
+ pfd.fd = socket_;
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(1));
+
+ // Now read and check that the packet is the one we just sent.
+ // Read and verify the data.
+ constexpr size_t packet_size =
+ sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage);
+ char buf[64];
+ struct sockaddr_ll src = {};
+ socklen_t src_len = sizeof(src);
+ ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0,
+ reinterpret_cast<struct sockaddr*>(&src), &src_len),
+ SyscallSucceedsWithValue(packet_size));
+
+ // sockaddr_ll ends with an 8 byte physical address field, but ethernet
+ // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
+ // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns
+ // sizeof(sockaddr_ll).
+ ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
+
+ // Verify the source address.
+ EXPECT_EQ(src.sll_family, AF_PACKET);
+ EXPECT_EQ(src.sll_ifindex, ifindex);
+ EXPECT_EQ(src.sll_halen, ETH_ALEN);
+ EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP);
+ EXPECT_EQ(src.sll_pkttype, PACKET_OUTGOING);
+  // Verify the link address of the interface matches that of the
+  // non-loopback interface address we stored above.
+ for (int i = 0; i < src.sll_halen; i++) {
+ EXPECT_EQ(src.sll_addr[i], hwaddr[i]);
+ }
+
+ // Verify the IP header.
+ struct iphdr ip = {};
+ memcpy(&ip, buf, sizeof(ip));
+ EXPECT_EQ(ip.ihl, 5);
+ EXPECT_EQ(ip.version, 4);
+ EXPECT_EQ(ip.tot_len, htons(packet_size));
+ EXPECT_EQ(ip.protocol, IPPROTO_UDP);
+ EXPECT_EQ(ip.daddr, dest.sin_addr.s_addr);
+ EXPECT_NE(ip.saddr, htonl(INADDR_LOOPBACK));
+
+ // Verify the UDP header.
+ struct udphdr udp = {};
+ memcpy(&udp, buf + sizeof(iphdr), sizeof(udp));
+ EXPECT_EQ(udp.dest, kPort);
+ EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage)));
+
+ // Verify the payload.
+ char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr));
+ EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0);
+}
+
+// Bind with invalid address.
+TEST_P(CookedPacketTest, BindFail) {
+ // Null address.
+ ASSERT_THAT(
+ bind(socket_, nullptr, sizeof(struct sockaddr)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallFailsWithErrno(EINVAL)));
+
+ // Address of size 1.
+ uint8_t addr = 0;
+ ASSERT_THAT(
+ bind(socket_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest,
+ ::testing::Values(ETH_P_IP, ETH_P_ALL));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/ping_socket.cc b/test/syscalls/linux/ping_socket.cc
index 684983a4c..025c9568f 100644
--- a/test/syscalls/linux/ping_socket.cc
+++ b/test/syscalls/linux/ping_socket.cc
@@ -128,9 +128,6 @@ std::vector<std::tuple<SocketKind, BindTestCase>> ICMPTestCases() {
{
.bind_to = V4Broadcast(),
.want = EADDRNOTAVAIL,
- // TODO(gvisor.dev/issue/5711): Remove want_gvisor once ICMP
- // sockets are no longer allowed to bind to broadcast addresses.
- .want_gvisor = 0,
},
{
.bind_to = V4Loopback(),
@@ -139,9 +136,6 @@ std::vector<std::tuple<SocketKind, BindTestCase>> ICMPTestCases() {
{
.bind_to = V4LoopbackSubnetBroadcast(),
.want = EADDRNOTAVAIL,
- // TODO(gvisor.dev/issue/5711): Remove want_gvisor once ICMP
- // sockets are no longer allowed to bind to broadcast addresses.
- .want_gvisor = 0,
},
{
.bind_to = V4Multicast(),
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index 8a4025fed..cb443163c 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -17,7 +17,6 @@
#include <fcntl.h>
#include <limits.h>
#include <linux/magic.h>
-#include <linux/sem.h>
#include <sched.h>
#include <signal.h>
#include <stddef.h>
@@ -2453,54 +2452,6 @@ TEST(ProcFilesystems, Bug65172365) {
ASSERT_FALSE(proc_filesystems.empty());
}
-TEST(ProcFilesystems, PresenceOfShmMaxMniAll) {
- uint64_t shmmax = 0;
- uint64_t shmall = 0;
- uint64_t shmmni = 0;
- std::string proc_file;
- proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/shmmax"));
- ASSERT_FALSE(proc_file.empty());
- ASSERT_TRUE(absl::SimpleAtoi(proc_file, &shmmax));
- proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/shmall"));
- ASSERT_FALSE(proc_file.empty());
- ASSERT_TRUE(absl::SimpleAtoi(proc_file, &shmall));
- proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/shmmni"));
- ASSERT_FALSE(proc_file.empty());
- ASSERT_TRUE(absl::SimpleAtoi(proc_file, &shmmni));
-
- ASSERT_GT(shmmax, 0);
- ASSERT_GT(shmall, 0);
- ASSERT_GT(shmmni, 0);
- ASSERT_LE(shmall, shmmax);
-
- // These values should never be higher than this by default, for more
- // information see uapi/linux/shm.h
- ASSERT_LE(shmmax, ULONG_MAX - (1UL << 24));
- ASSERT_LE(shmall, ULONG_MAX - (1UL << 24));
-}
-
-TEST(ProcFilesystems, PresenceOfSem) {
- uint32_t semmsl = 0;
- uint32_t semmns = 0;
- uint32_t semopm = 0;
- uint32_t semmni = 0;
- std::string proc_file;
- proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/sem"));
- ASSERT_FALSE(proc_file.empty());
- std::vector<absl::string_view> sem_limits =
- absl::StrSplit(proc_file, absl::ByAnyChar("\t"), absl::SkipWhitespace());
- ASSERT_EQ(sem_limits.size(), 4);
- ASSERT_TRUE(absl::SimpleAtoi(sem_limits[0], &semmsl));
- ASSERT_TRUE(absl::SimpleAtoi(sem_limits[1], &semmns));
- ASSERT_TRUE(absl::SimpleAtoi(sem_limits[2], &semopm));
- ASSERT_TRUE(absl::SimpleAtoi(sem_limits[3], &semmni));
-
- ASSERT_EQ(semmsl, SEMMSL);
- ASSERT_EQ(semmns, SEMMNS);
- ASSERT_EQ(semopm, SEMOPM);
- ASSERT_EQ(semmni, SEMMNI);
-}
-
// Check that /proc/mounts is a symlink to self/mounts.
TEST(ProcMounts, IsSymlink) {
auto link = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/mounts"));
diff --git a/test/syscalls/linux/proc_isolated.cc b/test/syscalls/linux/proc_isolated.cc
new file mode 100644
index 000000000..38d079d2b
--- /dev/null
+++ b/test/syscalls/linux/proc_isolated.cc
@@ -0,0 +1,100 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/msg.h>
+#include <linux/sem.h>
+#include <linux/shm.h>
+
+#include "gtest/gtest.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
+#include "test/util/fs_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+TEST(ProcDefaults, PresenceOfShmMaxMniAll) {
+ uint64_t shmmax = 0;
+ uint64_t shmall = 0;
+ uint64_t shmmni = 0;
+ std::string proc_file;
+ proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/shmmax"));
+ ASSERT_FALSE(proc_file.empty());
+ ASSERT_TRUE(absl::SimpleAtoi(proc_file, &shmmax));
+ proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/shmall"));
+ ASSERT_FALSE(proc_file.empty());
+ ASSERT_TRUE(absl::SimpleAtoi(proc_file, &shmall));
+ proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/shmmni"));
+ ASSERT_FALSE(proc_file.empty());
+ ASSERT_TRUE(absl::SimpleAtoi(proc_file, &shmmni));
+
+ ASSERT_EQ(shmmax, SHMMAX);
+ ASSERT_EQ(shmall, SHMALL);
+ ASSERT_EQ(shmmni, SHMMNI);
+ ASSERT_LE(shmall, shmmax);
+
+ // These values should never be higher than this by default, for more
+ // information see uapi/linux/shm.h
+ ASSERT_LE(shmmax, ULONG_MAX - (1UL << 24));
+ ASSERT_LE(shmall, ULONG_MAX - (1UL << 24));
+}
+
+TEST(ProcDefaults, PresenceOfSem) {
+ uint32_t semmsl = 0;
+ uint32_t semmns = 0;
+ uint32_t semopm = 0;
+ uint32_t semmni = 0;
+ std::string proc_file;
+ proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/sem"));
+ ASSERT_FALSE(proc_file.empty());
+ std::vector<absl::string_view> sem_limits =
+ absl::StrSplit(proc_file, absl::ByAnyChar("\t"), absl::SkipWhitespace());
+ ASSERT_EQ(sem_limits.size(), 4);
+ ASSERT_TRUE(absl::SimpleAtoi(sem_limits[0], &semmsl));
+ ASSERT_TRUE(absl::SimpleAtoi(sem_limits[1], &semmns));
+ ASSERT_TRUE(absl::SimpleAtoi(sem_limits[2], &semopm));
+ ASSERT_TRUE(absl::SimpleAtoi(sem_limits[3], &semmni));
+
+ ASSERT_EQ(semmsl, SEMMSL);
+ ASSERT_EQ(semmns, SEMMNS);
+ ASSERT_EQ(semopm, SEMOPM);
+ ASSERT_EQ(semmni, SEMMNI);
+}
+
+TEST(ProcDefaults, PresenceOfMsgMniMaxMnb) {
+ uint64_t msgmni = 0;
+ uint64_t msgmax = 0;
+ uint64_t msgmnb = 0;
+
+ std::string proc_file;
+ proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/msgmni"));
+ ASSERT_FALSE(proc_file.empty());
+ ASSERT_TRUE(absl::SimpleAtoi(proc_file, &msgmni));
+ proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/msgmax"));
+ ASSERT_FALSE(proc_file.empty());
+ ASSERT_TRUE(absl::SimpleAtoi(proc_file, &msgmax));
+ proc_file = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/sys/kernel/msgmnb"));
+ ASSERT_FALSE(proc_file.empty());
+ ASSERT_TRUE(absl::SimpleAtoi(proc_file, &msgmnb));
+
+ ASSERT_EQ(msgmni, MSGMNI);
+ ASSERT_EQ(msgmax, MSGMAX);
+ ASSERT_EQ(msgmnb, MSGMNB);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc
index ef1db47ee..ef176cbee 100644
--- a/test/syscalls/linux/raw_socket.cc
+++ b/test/syscalls/linux/raw_socket.cc
@@ -25,6 +25,7 @@
#include <algorithm>
#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
@@ -39,6 +40,9 @@ namespace testing {
namespace {
+using ::testing::IsNull;
+using ::testing::NotNull;
+
// Fixture for tests parameterized by protocol.
class RawSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> {
protected:
@@ -1057,6 +1061,131 @@ TEST(RawSocketTest, BindReceive) {
ASSERT_NO_FATAL_FAILURE(TestRawSocketMaybeBindReceive(true /* do_bind */));
}
+TEST(RawSocketTest, ReceiveIPPacketInfo) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveRawIPSocketCapability()));
+
+ FileDescriptor raw =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
+
+ const sockaddr_in addr_ = {
+ .sin_family = AF_INET,
+ .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)},
+ };
+ ASSERT_THAT(
+ bind(raw.get(), reinterpret_cast<const sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+
+ // Register to receive IP packet info.
+ constexpr int one = 1;
+ ASSERT_THAT(setsockopt(raw.get(), IPPROTO_IP, IP_PKTINFO, &one, sizeof(one)),
+ SyscallSucceeds());
+
+ constexpr char send_buf[] = "malformed UDP";
+ ASSERT_THAT(sendto(raw.get(), send_buf, sizeof(send_buf), 0 /* flags */,
+ reinterpret_cast<const sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ struct {
+ iphdr ip;
+ char data[sizeof(send_buf)];
+
+ // Extra space in the receive buffer should be unused.
+ char unused_space;
+ } ABSL_ATTRIBUTE_PACKED recv_buf;
+ iovec recv_iov = {
+ .iov_base = &recv_buf,
+ .iov_len = sizeof(recv_buf),
+ };
+ in_pktinfo received_pktinfo;
+ char recv_cmsg_buf[CMSG_SPACE(sizeof(received_pktinfo))];
+ msghdr recv_msg = {
+ .msg_iov = &recv_iov,
+ .msg_iovlen = 1,
+ .msg_control = recv_cmsg_buf,
+ .msg_controllen = CMSG_LEN(sizeof(received_pktinfo)),
+ };
+ ASSERT_THAT(RetryEINTR(recvmsg)(raw.get(), &recv_msg, 0),
+ SyscallSucceedsWithValue(sizeof(iphdr) + sizeof(send_buf)));
+ EXPECT_EQ(memcmp(send_buf, &recv_buf.data, sizeof(send_buf)), 0);
+ EXPECT_EQ(recv_buf.ip.version, static_cast<unsigned int>(IPVERSION));
+ // IHL holds the number of header bytes in 4 byte units.
+ EXPECT_EQ(recv_buf.ip.ihl, sizeof(iphdr) / 4);
+ EXPECT_EQ(ntohs(recv_buf.ip.tot_len), sizeof(iphdr) + sizeof(send_buf));
+ EXPECT_EQ(recv_buf.ip.protocol, IPPROTO_UDP);
+ EXPECT_EQ(ntohl(recv_buf.ip.saddr), INADDR_LOOPBACK);
+ EXPECT_EQ(ntohl(recv_buf.ip.daddr), INADDR_LOOPBACK);
+
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&recv_msg);
+ ASSERT_THAT(cmsg, NotNull());
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(received_pktinfo)));
+ EXPECT_EQ(cmsg->cmsg_level, IPPROTO_IP);
+ EXPECT_EQ(cmsg->cmsg_type, IP_PKTINFO);
+ memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(received_pktinfo));
+ EXPECT_EQ(received_pktinfo.ipi_ifindex,
+ ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()));
+ EXPECT_EQ(ntohl(received_pktinfo.ipi_spec_dst.s_addr), INADDR_LOOPBACK);
+ EXPECT_EQ(ntohl(received_pktinfo.ipi_addr.s_addr), INADDR_LOOPBACK);
+
+ EXPECT_THAT(CMSG_NXTHDR(&recv_msg, cmsg), IsNull());
+}
+
+TEST(RawSocketTest, ReceiveIPv6PacketInfo) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveRawIPSocketCapability()));
+
+ FileDescriptor raw =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_RAW, IPPROTO_UDP));
+
+ const sockaddr_in6 addr_ = {
+ .sin6_family = AF_INET6,
+ .sin6_addr = in6addr_loopback,
+ };
+ ASSERT_THAT(
+ bind(raw.get(), reinterpret_cast<const sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceeds());
+
+ // Register to receive IPv6 packet info.
+ constexpr int one = 1;
+ ASSERT_THAT(
+ setsockopt(raw.get(), IPPROTO_IPV6, IPV6_RECVPKTINFO, &one, sizeof(one)),
+ SyscallSucceeds());
+
+ constexpr char send_buf[] = "malformed UDP";
+ ASSERT_THAT(sendto(raw.get(), send_buf, sizeof(send_buf), 0 /* flags */,
+ reinterpret_cast<const sockaddr*>(&addr_), sizeof(addr_)),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ char recv_buf[sizeof(send_buf) + 1];
+ iovec recv_iov = {
+ .iov_base = recv_buf,
+ .iov_len = sizeof(recv_buf),
+ };
+ in6_pktinfo received_pktinfo;
+ char recv_cmsg_buf[CMSG_SPACE(sizeof(received_pktinfo))];
+ msghdr recv_msg = {
+ .msg_iov = &recv_iov,
+ .msg_iovlen = 1,
+ .msg_control = recv_cmsg_buf,
+ .msg_controllen = CMSG_LEN(sizeof(received_pktinfo)),
+ };
+ ASSERT_THAT(RetryEINTR(recvmsg)(raw.get(), &recv_msg, 0),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ EXPECT_EQ(memcmp(send_buf, recv_buf, sizeof(send_buf)), 0);
+
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&recv_msg);
+ ASSERT_THAT(cmsg, NotNull());
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(received_pktinfo)));
+ EXPECT_EQ(cmsg->cmsg_level, IPPROTO_IPV6);
+ EXPECT_EQ(cmsg->cmsg_type, IPV6_PKTINFO);
+ memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(received_pktinfo));
+ EXPECT_EQ(received_pktinfo.ipi6_ifindex,
+ ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()));
+ ASSERT_EQ(memcmp(&received_pktinfo.ipi6_addr, &in6addr_loopback,
+ sizeof(in6addr_loopback)),
+ 0);
+
+ EXPECT_THAT(CMSG_NXTHDR(&recv_msg, cmsg), IsNull());
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index 8a87f2667..2f45bb5ad 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -493,7 +493,7 @@ TEST_P(UDPSocketPairTest, TClassRecvMismatch) {
// This should only test AF_INET6 sockets for the mismatch behavior.
SKIP_IF(GetParam().domain != AF_INET6);
// IPV6_RECVTCLASS is only valid for SOCK_DGRAM and SOCK_RAW.
- SKIP_IF((GetParam().type != SOCK_DGRAM) | (GetParam().type != SOCK_RAW));
+ SKIP_IF((GetParam().type != SOCK_DGRAM) && (GetParam().type != SOCK_RAW));
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/socket_ip_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ip_udp_unbound_external_networking.cc
deleted file mode 100644
index af2459a2f..000000000
--- a/test/syscalls/linux/socket_ip_udp_unbound_external_networking.cc
+++ /dev/null
@@ -1,59 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "test/syscalls/linux/socket_ip_udp_unbound_external_networking.h"
-
-#include "test/util/socket_util.h"
-#include "test/util/test_util.h"
-
-namespace gvisor {
-namespace testing {
-
-void IPUDPUnboundExternalNetworkingSocketTest::SetUp() {
- // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its
- // IPv4 address on eth0.
- found_net_interfaces_ = false;
-
- // Get interface list.
- ASSERT_NO_ERRNO(if_helper_.Load());
- std::vector<std::string> if_names = if_helper_.InterfaceList(AF_INET);
- if (if_names.size() != 2) {
- return;
- }
-
- // Figure out which interface is where.
- std::string lo = if_names[0];
- std::string eth = if_names[1];
- if (lo != "lo") std::swap(lo, eth);
- if (lo != "lo") return;
-
- lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(lo));
- auto lo_if_addr = if_helper_.GetAddr(AF_INET, lo);
- if (lo_if_addr == nullptr) {
- return;
- }
- lo_if_addr_ = *reinterpret_cast<const sockaddr_in*>(lo_if_addr);
-
- eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(eth));
- auto eth_if_addr = if_helper_.GetAddr(AF_INET, eth);
- if (eth_if_addr == nullptr) {
- return;
- }
- eth_if_addr_ = *reinterpret_cast<const sockaddr_in*>(eth_if_addr);
-
- found_net_interfaces_ = true;
-}
-
-} // namespace testing
-} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ip_udp_unbound_external_networking.h
index 2e8aab129..92c20eba9 100644
--- a/test/syscalls/linux/socket_ip_udp_unbound_external_networking.h
+++ b/test/syscalls/linux/socket_ip_udp_unbound_external_networking.h
@@ -16,29 +16,13 @@
#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IP_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
#include "test/syscalls/linux/ip_socket_test_util.h"
-#include "test/util/socket_util.h"
namespace gvisor {
namespace testing {
// Test fixture for tests that apply to unbound IP UDP sockets in a sandbox
// with external networking support.
-class IPUDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest {
- protected:
- void SetUp() override;
-
- IfAddrHelper if_helper_;
-
- // found_net_interfaces_ is set to false if SetUp() could not obtain
- // all interface infos that we need.
- bool found_net_interfaces_;
-
- // Interface infos.
- int lo_if_idx_;
- int eth_if_idx_;
- sockaddr_in lo_if_addr_;
- sockaddr_in eth_if_addr_;
-};
+class IPUDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest {};
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc
index 9ae68e015..4ab4ef261 100644
--- a/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc
+++ b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc
@@ -18,11 +18,11 @@
namespace gvisor {
namespace testing {
-INSTANTIATE_TEST_SUITE_P(IPv4Sockets, IPv4DatagramBasedUnboundSocketTest,
- ::testing::ValuesIn(ApplyVecToVec<SocketKind>(
- {IPv4UDPUnboundSocket, IPv4RawUDPUnboundSocket},
- AllBitwiseCombinations(List<int>{
- 0, SOCK_NONBLOCK}))));
+INSTANTIATE_TEST_SUITE_P(
+ IPv4Sockets, IPv4DatagramBasedUnboundSocketTest,
+ ::testing::ValuesIn(ApplyVecToVec<SocketKind>(
+ {IPv4UDPUnboundSocket, IPv4RawUDPUnboundSocket, ICMPUnboundSocket},
+ AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK}))));
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
index c6e775b2a..6c67ec51e 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
@@ -14,9 +14,62 @@
#include "test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h"
+#include <net/if.h>
+
+#include "absl/cleanup/cleanup.h"
+#include "test/util/socket_util.h"
+#include "test/util/test_util.h"
+
namespace gvisor {
namespace testing {
+void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() {
+#ifdef ANDROID
+ GTEST_SKIP() << "Android does not support getifaddrs in r22";
+#endif
+
+ ifaddrs* ifaddr;
+ ASSERT_THAT(getifaddrs(&ifaddr), SyscallSucceeds());
+ auto cleanup = absl::MakeCleanup([ifaddr] { freeifaddrs(ifaddr); });
+
+ for (const ifaddrs* ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next) {
+ ASSERT_NE(ifa->ifa_name, nullptr);
+ ASSERT_NE(ifa->ifa_addr, nullptr);
+
+ if (ifa->ifa_addr->sa_family != AF_INET) {
+ continue;
+ }
+
+ std::optional<std::pair<int, sockaddr_in>>& if_pair = *[this, ifa]() {
+ if (strcmp(ifa->ifa_name, "lo") == 0) {
+ return &lo_if_;
+ }
+ return &eth_if_;
+ }();
+
+ const int if_index =
+ ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex(ifa->ifa_name));
+
+ std::cout << " name=" << ifa->ifa_name
+ << " addr=" << GetAddrStr(ifa->ifa_addr) << " index=" << if_index
+ << " has_value=" << if_pair.has_value() << std::endl;
+
+ if (if_pair.has_value()) {
+ continue;
+ }
+
+ if_pair = std::make_pair(
+ if_index, *reinterpret_cast<const sockaddr_in*>(ifa->ifa_addr));
+ }
+
+ if (!(eth_if_.has_value() && lo_if_.has_value())) {
+ // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its
+ // IPv4 address on eth0.
+ GTEST_SKIP() << " eth_if_.has_value()=" << eth_if_.has_value()
+ << " lo_if_.has_value()=" << lo_if_.has_value();
+ }
+}
+
TestAddress V4EmptyAddress() {
TestAddress t("V4Empty");
t.addr.ss_family = AF_INET;
@@ -28,7 +81,6 @@ TestAddress V4EmptyAddress() {
// the destination port number.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
UDPBroadcastReceivedOnExpectedPort) {
- SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -101,8 +153,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// not a unicast address.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
UDPBroadcastReceivedOnExpectedAddresses) {
- SKIP_IF(!found_net_interfaces_);
-
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -149,7 +199,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Bind the non-receiving socket to the unicast ethernet address.
auto norecv_addr = rcv1_addr;
reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr =
- eth_if_addr_.sin_addr;
+ eth_if_addr().sin_addr;
ASSERT_THAT(
bind(norcv->get(), AsSockAddr(&norecv_addr.addr), norecv_addr.addr_len),
SyscallSucceedsWithValue(0));
@@ -184,7 +234,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// (UDPBroadcastSendRecvOnSocketBoundToAny).
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
UDPBroadcastSendRecvOnSocketBoundToBroadcast) {
- SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Enable SO_BROADCAST.
@@ -224,7 +273,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// (UDPBroadcastSendRecvOnSocketBoundToBroadcast).
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
UDPBroadcastSendRecvOnSocketBoundToAny) {
- SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Enable SO_BROADCAST.
@@ -261,7 +309,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Verifies that a UDP broadcast fails to send on a socket with SO_BROADCAST
// disabled.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendBroadcast) {
- SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Broadcast a test message without having enabled SO_BROADCAST on the sending
@@ -306,12 +353,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendUnicastOnUnbound) {
// set interface or group membership.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
TestSendMulticastSelfNoGroup) {
- // FIXME(b/125485338): A group membership is not required for external
- // multicast on gVisor.
- SKIP_IF(IsRunningOnGvisor());
-
- SKIP_IF(!found_net_interfaces_);
-
auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto bind_addr = V4Any();
@@ -345,7 +386,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Check that multicast packets will be delivered to the sending socket without
// setting an interface.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) {
- SKIP_IF(!found_net_interfaces_);
auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto bind_addr = V4Any();
@@ -388,7 +428,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) {
// set interface and IP_MULTICAST_LOOP disabled.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
TestSendMulticastSelfLoopOff) {
- SKIP_IF(!found_net_interfaces_);
auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto bind_addr = V4Any();
@@ -434,12 +473,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Check that multicast packets won't be delivered to another socket with no
// set interface or group membership.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) {
- // FIXME(b/125485338): A group membership is not required for external
- // multicast on gVisor.
- SKIP_IF(IsRunningOnGvisor());
-
- SKIP_IF(!found_net_interfaces_);
-
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -476,7 +509,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) {
// Check that multicast packets will be delivered to another socket without
// setting an interface.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) {
- SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -522,7 +554,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) {
// set interface and IP_MULTICAST_LOOP disabled on the sending socket.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
TestSendMulticastSenderNoLoop) {
- SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -572,8 +603,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// setting an interface and IP_MULTICAST_LOOP disabled on the receiving socket.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
TestSendMulticastReceiverNoLoop) {
- SKIP_IF(!found_net_interfaces_);
-
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -624,7 +653,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// and both will receive data on it when bound to the ANY address.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
TestSendMulticastToTwoBoundToAny) {
- SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
std::unique_ptr<FileDescriptor> receivers[2] = {
ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
@@ -689,7 +717,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// and both will receive data on it when bound to the multicast address.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
TestSendMulticastToTwoBoundToMulticastAddress) {
- SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
std::unique_ptr<FileDescriptor> receivers[2] = {
ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
@@ -757,7 +784,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// multicast address, both will receive data.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
TestSendMulticastToTwoBoundToAnyAndMulticastAddress) {
- SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
std::unique_ptr<FileDescriptor> receivers[2] = {
ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
@@ -829,8 +855,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// is not a multicast address.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
IpMulticastLoopbackFromAddr) {
- SKIP_IF(!found_net_interfaces_);
-
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -893,8 +917,6 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// interface, a multicast packet sent out uses the latter as its source address.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
IpMulticastLoopbackIfNicAndAddr) {
- SKIP_IF(!found_net_interfaces_);
-
// Create receiver, bind to ANY and join the multicast group.
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver_addr = V4Any();
@@ -906,11 +928,15 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
- int receiver_port =
+ const in_port_t receiver_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
- ip_mreqn group = {};
- group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- group.imr_ifindex = lo_if_idx_;
+ const ip_mreqn group = {
+ .imr_multiaddr =
+ {
+ .s_addr = inet_addr(kMulticastAddress),
+ },
+ .imr_ifindex = lo_if_idx(),
+ };
ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
sizeof(group)),
SyscallSucceeds());
@@ -918,9 +944,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Set outgoing multicast interface config, with NIC and addr pointing to
// different interfaces.
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- ip_mreqn iface = {};
- iface.imr_ifindex = lo_if_idx_;
- iface.imr_address = eth_if_addr_.sin_addr;
+ const ip_mreqn iface = {
+ .imr_address = eth_if_addr().sin_addr,
+ .imr_ifindex = lo_if_idx(),
+ };
ASSERT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
sizeof(iface)),
SyscallSucceeds());
@@ -928,67 +955,104 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Send a multicast packet.
auto sendto_addr = V4Multicast();
reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = receiver_port;
- char send_buf[4] = {};
+ char send_buf[4];
ASSERT_THAT(
RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
AsSockAddr(&sendto_addr.addr), sendto_addr.addr_len),
SyscallSucceedsWithValue(sizeof(send_buf)));
// Receive a multicast packet.
- char recv_buf[sizeof(send_buf)] = {};
+ char recv_buf[sizeof(send_buf) + 1];
auto src_addr = V4EmptyAddress();
ASSERT_THAT(
RetryEINTR(recvfrom)(receiver->get(), recv_buf, sizeof(recv_buf), 0,
AsSockAddr(&src_addr.addr), &src_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
- ASSERT_EQ(sizeof(struct sockaddr_in), src_addr.addr_len);
- sockaddr_in* src_addr_in = reinterpret_cast<sockaddr_in*>(&src_addr.addr);
-
- // FIXME (b/137781162): When sending a multicast packet use the proper logic
- // to determine the packet's src-IP.
- SKIP_IF(IsRunningOnGvisor());
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_EQ(src_addr.addr_len, sizeof(struct sockaddr_in));
// Verify the received source address.
- EXPECT_EQ(eth_if_addr_.sin_addr.s_addr, src_addr_in->sin_addr.s_addr);
+ //
+ // TODO(https://gvisor.dev/issue/6686): gVisor is a strong host, preventing
+ // the packet from being sent from the loopback device using the ethernet
+ // device's address.
+ if (IsRunningOnGvisor()) {
+ EXPECT_EQ(GetAddrStr(AsSockAddr(&src_addr.addr)),
+ GetAddr4Str(&lo_if_addr().sin_addr));
+ } else {
+ EXPECT_EQ(GetAddrStr(AsSockAddr(&src_addr.addr)),
+ GetAddr4Str(&eth_if_addr().sin_addr));
+ }
}
// Check that when we are bound to one interface we can set IP_MULTICAST_IF to
// another interface.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
IpMulticastLoopbackBindToOneIfSetMcastIfToAnother) {
- SKIP_IF(!found_net_interfaces_);
-
- // FIXME (b/137790511): When bound to one interface it is not possible to set
- // IP_MULTICAST_IF to a different interface.
- SKIP_IF(IsRunningOnGvisor());
-
- // Create sender and bind to eth interface.
- auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- ASSERT_THAT(
- bind(sender->get(), AsSockAddr(&eth_if_addr_), sizeof(eth_if_addr_)),
- SyscallSucceeds());
-
// Run through all possible combinations of index and address for
// IP_MULTICAST_IF that selects the loopback interface.
- struct {
- int imr_ifindex;
- struct in_addr imr_address;
- } test_data[] = {
- {lo_if_idx_, {}},
- {0, lo_if_addr_.sin_addr},
- {lo_if_idx_, lo_if_addr_.sin_addr},
- {lo_if_idx_, eth_if_addr_.sin_addr},
+ ip_mreqn ifaces[] = {
+ {
+ .imr_address = {},
+ .imr_ifindex = lo_if_idx(),
+ },
+ {
+ .imr_address = lo_if_addr().sin_addr,
+ .imr_ifindex = 0,
+ },
+ {
+ .imr_address = lo_if_addr().sin_addr,
+ .imr_ifindex = lo_if_idx(),
+ },
+ {
+ .imr_address = eth_if_addr().sin_addr,
+ .imr_ifindex = lo_if_idx(),
+ },
};
- for (auto t : test_data) {
- ip_mreqn iface = {};
- iface.imr_ifindex = t.imr_ifindex;
- iface.imr_address = t.imr_address;
- EXPECT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
- sizeof(iface)),
- SyscallSucceeds())
- << "imr_index=" << iface.imr_ifindex
- << " imr_address=" << GetAddr4Str(&iface.imr_address);
+
+ {
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ ASSERT_THAT(
+ bind(sender->get(), AsSockAddr(&eth_if_addr()), sizeof(eth_if_addr())),
+ SyscallSucceeds());
+
+ for (const ip_mreqn& iface : ifaces) {
+ EXPECT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
+ SyscallSucceeds())
+ << " imr_index=" << iface.imr_ifindex
+ << " imr_address=" << GetAddr4Str(&iface.imr_address);
+ }
+ }
+
+ {
+ char eth_if_name[IF_NAMESIZE];
+ memset(eth_if_name, 0xAA, sizeof(eth_if_name));
+ ASSERT_NE(if_indextoname(eth_if_idx(), eth_if_name), nullptr)
+ << strerror(errno);
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BINDTODEVICE,
+ eth_if_name, sizeof(eth_if_name)),
+ SyscallSucceeds());
+
+ for (const ip_mreqn& iface : ifaces) {
+ // FIXME(b/137790511): Disallow mismatching IP_MULTICAST_IF and
+ // SO_BINDTODEVICE.
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF,
+ &iface, sizeof(iface)),
+ SyscallSucceeds())
+ << " imr_index=" << iface.imr_ifindex
+ << " imr_address=" << GetAddr4Str(&iface.imr_address);
+ } else {
+ EXPECT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF,
+ &iface, sizeof(iface)),
+ SyscallFailsWithErrno(EINVAL))
+ << " imr_index=" << iface.imr_ifindex
+ << " imr_address=" << GetAddr4Str(&iface.imr_address);
+ }
+ }
}
}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
index 20922ac1f..ac917b32c 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
@@ -22,8 +22,22 @@ namespace testing {
// Test fixture for tests that apply to unbound IPv4 UDP sockets in a sandbox
// with external networking support.
-using IPv4UDPUnboundExternalNetworkingSocketTest =
- IPUDPUnboundExternalNetworkingSocketTest;
+class IPv4UDPUnboundExternalNetworkingSocketTest
+ : public IPUDPUnboundExternalNetworkingSocketTest {
+ protected:
+ void SetUp() override;
+
+ int lo_if_idx() const { return std::get<0>(lo_if_.value()); }
+ int eth_if_idx() const { return std::get<0>(eth_if_.value()); }
+
+ const sockaddr_in& lo_if_addr() const { return std::get<1>(lo_if_.value()); }
+ const sockaddr_in& eth_if_addr() const {
+ return std::get<1>(eth_if_.value());
+ }
+
+ private:
+ std::optional<std::pair<int, sockaddr_in>> lo_if_, eth_if_;
+};
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound.cc b/test/syscalls/linux/socket_ipv6_udp_unbound.cc
index 612fd531c..c431bf771 100644
--- a/test/syscalls/linux/socket_ipv6_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound.cc
@@ -39,6 +39,86 @@
namespace gvisor {
namespace testing {
+using ::testing::IsNull;
+using ::testing::NotNull;
+
+TEST_P(IPv6UDPUnboundSocketTest, IPv6PacketInfo) {
+ auto sender_socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver_socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ auto sender_addr = V6Loopback();
+ ASSERT_THAT(bind(sender_socket->get(), AsSockAddr(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+
+ auto receiver_addr = V6Loopback();
+ ASSERT_THAT(bind(receiver_socket->get(), AsSockAddr(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver_socket->get(),
+ AsSockAddr(&receiver_addr.addr), &receiver_addr_len),
+ SyscallSucceeds());
+ ASSERT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Make sure we get IPv6 packet information as control messages.
+ constexpr int one = 1;
+ ASSERT_THAT(setsockopt(receiver_socket->get(), IPPROTO_IPV6, IPV6_RECVPKTINFO,
+ &one, sizeof(one)),
+ SyscallSucceeds());
+
+ // Send a packet - we don't care about the packet itself, only the returned
+ // IPV6_PKTINFO control message.
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(RetryEINTR(sendto)(
+ sender_socket->get(), send_buf, sizeof(send_buf), 0,
+ AsSockAddr(&receiver_addr.addr), receiver_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the packet with the packet information control
+ // message.
+ char recv_buf[sizeof(send_buf) + 1];
+ in6_pktinfo received_pktinfo;
+ char recv_cmsg_buf[CMSG_SPACE(sizeof(received_pktinfo))];
+ iovec recv_iov = {
+ .iov_base = recv_buf,
+ .iov_len = sizeof(recv_buf),
+ };
+ msghdr recv_msg = {
+ .msg_iov = &recv_iov,
+ .msg_iovlen = 1,
+ .msg_control = recv_cmsg_buf,
+ .msg_controllen = sizeof(recv_cmsg_buf),
+ };
+ ASSERT_THAT(
+ RetryEINTR(recvmsg)(receiver_socket->get(), &recv_msg, 0 /* flags */),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ EXPECT_EQ(memcmp(send_buf, recv_buf, sizeof(send_buf)), 0);
+
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&recv_msg);
+ ASSERT_THAT(cmsg, NotNull());
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(in6_pktinfo)));
+ EXPECT_EQ(cmsg->cmsg_level, IPPROTO_IPV6);
+ EXPECT_EQ(cmsg->cmsg_type, IPV6_PKTINFO);
+ // As per cmsg(3) (https://www.man7.org/linux/man-pages/man3/cmsg.3.html),
+ //
+ // CMSG_DATA() returns a pointer to the data portion of a cmsghdr. The
+ // pointer returned cannot be assumed to be suitably aligned for accessing
+ // arbitrary payload data types. Applications should not cast it to a
+ // pointer type matching the payload, but should instead use memcpy(3) to
+ // copy data to or from a suitably declared object.
+ memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(received_pktinfo));
+ EXPECT_EQ(
+ memcmp(&received_pktinfo.ipi6_addr,
+ &(reinterpret_cast<sockaddr_in6*>(&sender_addr.addr)->sin6_addr),
+ sizeof(received_pktinfo.ipi6_addr)),
+ 0);
+ EXPECT_EQ(received_pktinfo.ipi6_ifindex,
+ ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()));
+ EXPECT_THAT(CMSG_NXTHDR(&recv_msg, cmsg), IsNull());
+}
+
// Test that socket will receive IP_RECVORIGDSTADDR control message.
TEST_P(IPv6UDPUnboundSocketTest, SetAndReceiveIPReceiveOrigDstAddr) {
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc
index d72b6b53f..c6e563cd3 100644
--- a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc
+++ b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc
@@ -18,8 +18,6 @@ namespace gvisor {
namespace testing {
TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) {
- SKIP_IF(!found_net_interfaces_);
-
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 3fbbf1423..9c0ff8780 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -706,7 +706,7 @@ TEST_P(TcpSocketTest, TcpInq) {
std::vector<char> control(CMSG_SPACE(sizeof(int)));
size = sizeof(buf);
struct iovec iov;
- for (int i = 0; size != 0; i += kChunk) {
+ while (size != 0) {
msg.msg_control = &control[0];
msg.msg_controllen = control.size();
@@ -2067,6 +2067,69 @@ TEST_P(SimpleTcpSocketTest, GetSocketAcceptConnWithShutdown) {
EXPECT_EQ(got, 0);
}
+void ShutdownConnectingSocket(int domain, int shutdown_mode) {
+ FileDescriptor bound_s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(domain, SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage bound_addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(domain));
+ socklen_t bound_addrlen = sizeof(bound_addr);
+
+ ASSERT_THAT(bind(bound_s.get(), AsSockAddr(&bound_addr), bound_addrlen),
+ SyscallSucceeds());
+
+ // Start listening. Use a zero backlog to only allow one connection in the
+ // accept queue.
+ ASSERT_THAT(listen(bound_s.get(), 0), SyscallSucceeds());
+
+ // Get the addresses the socket is bound to because the port is chosen by the
+ // stack.
+ ASSERT_THAT(
+ getsockname(bound_s.get(), AsSockAddr(&bound_addr), &bound_addrlen),
+ SyscallSucceeds());
+
+ // Establish a connection. But do not accept it. That way, subsequent
+ // connections will not get a SYN-ACK because the queue is full.
+ FileDescriptor connected_s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(domain, SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(connect(connected_s.get(),
+ reinterpret_cast<const struct sockaddr*>(&bound_addr),
+ bound_addrlen),
+ SyscallSucceeds());
+
+ FileDescriptor connecting_s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(domain, SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ ASSERT_THAT(connect(connecting_s.get(),
+ reinterpret_cast<const struct sockaddr*>(&bound_addr),
+ bound_addrlen),
+ SyscallFailsWithErrno(EINPROGRESS));
+
+ // Now the test: when a connecting socket is shutdown, the socket should enter
+ // an error state.
+ EXPECT_THAT(shutdown(connecting_s.get(), shutdown_mode), SyscallSucceeds());
+
+ // We don't need to specify any events to get POLLHUP or POLLERR because these
+ // are always tracked.
+ struct pollfd poll_fd = {
+ .fd = connecting_s.get(),
+ };
+
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 0), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(poll_fd.revents, POLLHUP | POLLERR);
+}
+
+TEST_P(SimpleTcpSocketTest, ShutdownReadConnectingSocket) {
+ ShutdownConnectingSocket(GetParam(), SHUT_RD);
+}
+
+TEST_P(SimpleTcpSocketTest, ShutdownWriteConnectingSocket) {
+ ShutdownConnectingSocket(GetParam(), SHUT_WR);
+}
+
+TEST_P(SimpleTcpSocketTest, ShutdownReadWriteConnectingSocket) {
+ ShutdownConnectingSocket(GetParam(), SHUT_RDWR);
+}
+
// Tests that connecting to an unspecified address results in ECONNREFUSED.
TEST_P(SimpleTcpSocketTest, ConnectUnspecifiedAddress) {
sockaddr_storage addr;
@@ -2088,6 +2151,66 @@ TEST_P(SimpleTcpSocketTest, ConnectUnspecifiedAddress) {
}
}
+TEST_P(SimpleTcpSocketTest, OnlyAcknowledgeBacklogConnections) {
+ // At some point, there was a bug in gVisor where a connection could be
+ // SYN-ACK'd by the server even if the accept queue was already full. This was
+ // possible because once the listener would process an ACK, it would move the
+ // new connection in the accept queue asynchronously. It created an
+ // opportunity where the listener could process another SYN before completing
+ // the delivery that would have filled the accept queue.
+ //
+ // This test checks that there is no such race.
+
+ std::array<std::optional<ScopedThread>, 100> threads;
+ for (auto& thread : threads) {
+ thread.emplace([]() {
+ FileDescriptor bound_s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage bound_addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t bound_addrlen = sizeof(bound_addr);
+
+ ASSERT_THAT(bind(bound_s.get(), AsSockAddr(&bound_addr), bound_addrlen),
+ SyscallSucceeds());
+
+ // Start listening. Use a zero backlog to only allow one connection in the
+ // accept queue.
+ ASSERT_THAT(listen(bound_s.get(), 0), SyscallSucceeds());
+
+ // Get the addresses the socket is bound to because the port is chosen by
+ // the stack.
+ ASSERT_THAT(
+ getsockname(bound_s.get(), AsSockAddr(&bound_addr), &bound_addrlen),
+ SyscallSucceeds());
+
+ // Establish a connection, but do not accept it.
+ FileDescriptor connected_s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(connect(connected_s.get(),
+ reinterpret_cast<const struct sockaddr*>(&bound_addr),
+ bound_addrlen),
+ SyscallSucceeds());
+
+ // Immediately attempt to establish another connection. Use non blocking
+ // socket because this is expected to timeout.
+ FileDescriptor connecting_s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ ASSERT_THAT(connect(connecting_s.get(),
+ reinterpret_cast<const struct sockaddr*>(&bound_addr),
+ bound_addrlen),
+ SyscallFailsWithErrno(EINPROGRESS));
+
+ struct pollfd poll_fd = {
+ .fd = connecting_s.get(),
+ .events = POLLOUT,
+ };
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10),
+ SyscallSucceedsWithValue(0));
+ });
+ }
+}
+
// Tests that send will return EWOULDBLOCK initially with large buffer and will
// succeed after the send buffer size is increased.
TEST_P(TcpSocketTest, SendUnblocksOnSendBufferIncrease) {
diff --git a/test/util/BUILD b/test/util/BUILD
index 5b6bcb32c..2dcf71613 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -350,18 +350,6 @@ cc_library(
)
cc_library(
- name = "benchmark_main",
- testonly = 1,
- linkstatic = 1,
- deps = [
- ":test_util",
- "@com_google_absl//absl/flags:flag",
- gtest,
- gbenchmark_internal,
- ],
-)
-
-cc_library(
name = "epoll_util",
testonly = 1,
srcs = ["epoll_util.cc"],
diff --git a/test/util/capability_util.h b/test/util/capability_util.h
index 318a43feb..ac1a1b32b 100644
--- a/test/util/capability_util.h
+++ b/test/util/capability_util.h
@@ -17,6 +17,8 @@
#ifndef GVISOR_TEST_UTIL_CAPABILITY_UTIL_H_
#define GVISOR_TEST_UTIL_CAPABILITY_UTIL_H_
+#include "test/util/posix_error.h"
+
#if defined(__Fuchsia__)
// Nothing to include.
#elif defined(__linux__)
diff --git a/test/util/cgroup_util.cc b/test/util/cgroup_util.cc
index df3c57b87..0308c2153 100644
--- a/test/util/cgroup_util.cc
+++ b/test/util/cgroup_util.cc
@@ -69,6 +69,9 @@ PosixError Cgroup::WriteControlFile(absl::string_view name,
const std::string& value) const {
ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, Open(Relpath(name), O_WRONLY));
RETURN_ERROR_IF_SYSCALL_FAIL(WriteFd(fd.get(), value.c_str(), value.size()));
+ const std::string alias_path = absl::StrFormat("[cg#%d]/%s", id_, name);
+ std::cerr << absl::StreamFormat("echo '%s' > %s", value, alias_path)
+ << std::endl;
return NoError();
}
diff --git a/test/util/socket_util.h b/test/util/socket_util.h
index 0e2be63cc..588f041b7 100644
--- a/test/util/socket_util.h
+++ b/test/util/socket_util.h
@@ -554,15 +554,27 @@ uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload,
inline sockaddr* AsSockAddr(sockaddr_storage* s) {
return reinterpret_cast<sockaddr*>(s);
}
+inline const sockaddr* AsSockAddr(const sockaddr_storage* s) {
+ return reinterpret_cast<const sockaddr*>(s);
+}
inline sockaddr* AsSockAddr(sockaddr_in* s) {
return reinterpret_cast<sockaddr*>(s);
}
+inline const sockaddr* AsSockAddr(const sockaddr_in* s) {
+ return reinterpret_cast<const sockaddr*>(s);
+}
inline sockaddr* AsSockAddr(sockaddr_in6* s) {
return reinterpret_cast<sockaddr*>(s);
}
+inline const sockaddr* AsSockAddr(const sockaddr_in6* s) {
+ return reinterpret_cast<const sockaddr*>(s);
+}
inline sockaddr* AsSockAddr(sockaddr_un* s) {
return reinterpret_cast<sockaddr*>(s);
}
+inline const sockaddr* AsSockAddr(const sockaddr_un* s) {
+ return reinterpret_cast<const sockaddr*>(s);
+}
PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr);
diff --git a/tools/bazel.mk b/tools/bazel.mk
index 1444423e4..68b804ec4 100644
--- a/tools/bazel.mk
+++ b/tools/bazel.mk
@@ -186,8 +186,8 @@ build_paths = \
(set -euo pipefail; \
$(call wrapper,$(BAZEL) build $(BASE_OPTIONS) $(BAZEL_OPTIONS) $(1)) && \
$(call wrapper,$(BAZEL) cquery $(BASE_OPTIONS) $(BAZEL_OPTIONS) $(1) --output=starlark --starlark:file=tools/show_paths.bzl) \
- | xargs -r -n 1 -I {} bash -c 'test -e "{}" || exit 0; readlink -f "{}"' \
- | xargs -r -n 1 -I {} bash -c 'set -euo pipefail; $(2)')
+ | xargs -r -I {} bash -c 'test -e "{}" || exit 0; readlink -f "{}"' \
+ | xargs -r -I {} bash -c 'set -euo pipefail; $(2)')
clean = $(call header,CLEAN) && $(call wrapper,$(BAZEL) clean)
build = $(call header,BUILD $(1)) && $(call build_paths,$(1),echo {})
diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go
index 082410697..5aa1fe5dc 100644
--- a/tools/bigquery/bigquery.go
+++ b/tools/bigquery/bigquery.go
@@ -40,29 +40,45 @@ type Suite struct {
}
func (s *Suite) String() string {
- conditions := make([]string, 0, len(s.Conditions))
- for _, c := range s.Conditions {
- conditions = append(conditions, c.String())
- }
- benchmarks := make([]string, 0, len(s.Benchmarks))
- for _, b := range s.Benchmarks {
- benchmarks = append(benchmarks, b.String())
- }
+ var sb strings.Builder
+ s.debugString(&sb, "")
+ return sb.String()
+}
- format := `Suite:
-Name: %s
-Conditions: %s
-Benchmarks: %s
-Official: %t
-Timestamp: %s
-`
+// writeLine writes a line of text to the given string builder with a prefix.
+func writeLine(sb *strings.Builder, prefix string, format string, values ...interface{}) {
+ if prefix != "" {
+ sb.WriteString(prefix)
+ }
+ sb.WriteString(fmt.Sprintf(format, values...))
+ sb.WriteString("\n")
+}
- return fmt.Sprintf(format,
- s.Name,
- strings.Join(conditions, "\n"),
- strings.Join(benchmarks, "\n"),
- s.Official,
- s.Timestamp)
+// debugString writes debug information to the given string builder with the
+// given prefix.
+func (s *Suite) debugString(sb *strings.Builder, prefix string) {
+ writeLine(sb, prefix, "Benchmark suite %s:", s.Name)
+ writeLine(sb, prefix, "Timestamp: %v", s.Timestamp)
+ if !s.Official {
+ writeLine(sb, prefix, " **** NOTE: Data is not official. **** ")
+ }
+ if numConditions := len(s.Conditions); numConditions == 0 {
+ writeLine(sb, prefix, "Conditions: None.")
+ } else {
+ writeLine(sb, prefix, "Conditions (%d):", numConditions)
+ for _, condition := range s.Conditions {
+ condition.debugString(sb, prefix+" ")
+ }
+ }
+ if numBenchmarks := len(s.Benchmarks); numBenchmarks == 0 {
+ writeLine(sb, prefix, "Benchmarks: None.")
+ } else {
+ writeLine(sb, prefix, "Benchmarks (%d):", numBenchmarks)
+ for _, benchmark := range s.Benchmarks {
+ benchmark.debugString(sb, prefix+" ")
+ }
+ }
+ sb.WriteString(fmt.Sprintf("End of data for benchmark suite %s.", s.Name))
}
// Benchmark represents an individual benchmark in a suite.
@@ -74,25 +90,31 @@ type Benchmark struct {
// String implements the String method for Benchmark
func (bm *Benchmark) String() string {
- conditions := make([]string, 0, len(bm.Condition))
- for _, c := range bm.Condition {
- conditions = append(conditions, c.String())
+ var sb strings.Builder
+ bm.debugString(&sb, "")
+ return sb.String()
+}
+
+// debugString writes debug information to the given string builder with the
+// given prefix.
+func (bm *Benchmark) debugString(sb *strings.Builder, prefix string) {
+ writeLine(sb, prefix, "Benchmark: %s", bm.Name)
+ if numConditions := len(bm.Condition); numConditions == 0 {
+ writeLine(sb, prefix, " Conditions: None.")
+ } else {
+ writeLine(sb, prefix, " Conditions (%d):", numConditions)
+ for _, condition := range bm.Condition {
+ condition.debugString(sb, prefix+" ")
+ }
}
- metrics := make([]string, 0, len(bm.Metric))
- for _, m := range bm.Metric {
- metrics = append(metrics, m.String())
+ if numMetrics := len(bm.Metric); numMetrics == 0 {
+ writeLine(sb, prefix, " Metrics: None.")
+ } else {
+ writeLine(sb, prefix, " Metrics (%d):", numMetrics)
+ for _, metric := range bm.Metric {
+ metric.debugString(sb, prefix+" ")
+ }
}
-
- format := `Condition:
-Name: %s
-Conditions: %s
-Metrics: %s
-`
-
- return fmt.Sprintf(format,
- bm.Name,
- strings.Join(conditions, "\n"),
- strings.Join(metrics, "\n"))
}
// AddMetric adds a metric to an existing Benchmark.
@@ -107,10 +129,7 @@ func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) {
// AddCondition adds a condition to an existing Benchmark.
func (bm *Benchmark) AddCondition(name, value string) {
- bm.Condition = append(bm.Condition, &Condition{
- Name: name,
- Value: value,
- })
+ bm.Condition = append(bm.Condition, NewCondition(name, value))
}
// NewBenchmark initializes a new benchmark.
@@ -136,8 +155,24 @@ type Condition struct {
Value string `bq:"value"`
}
+// NewCondition returns a new Condition with the given name and value.
+func NewCondition(name, value string) *Condition {
+ return &Condition{
+ Name: name,
+ Value: value,
+ }
+}
+
func (c *Condition) String() string {
- return fmt.Sprintf("Condition:\nName: %s Value: %s\n", c.Name, c.Value)
+ var sb strings.Builder
+ c.debugString(&sb, "")
+ return sb.String()
+}
+
+// debugString writes debug information to the given string builder with the
+// given prefix.
+func (c *Condition) debugString(sb *strings.Builder, prefix string) {
+ writeLine(sb, prefix, "Condition: %s = %s", c.Name, c.Value)
}
// Metric holds the actual metric data and unit information for this benchmark.
@@ -148,7 +183,15 @@ type Metric struct {
}
func (m *Metric) String() string {
- return fmt.Sprintf("Metric:\nName: %s Unit: %s Sample: %e\n", m.Name, m.Unit, m.Sample)
+ var sb strings.Builder
+ m.debugString(&sb, "")
+ return sb.String()
+}
+
+// debugString writes debug information to the given string builder with the
+// given prefix.
+func (m *Metric) debugString(sb *strings.Builder, prefix string) {
+ writeLine(sb, prefix, "Metric %s: %f %s", m.Name, m.Sample, m.Unit)
}
// InitBigQuery initializes a BigQuery dataset/table in the project. If the dataset/table already exists, it is not duplicated.
diff --git a/tools/checklocks/analysis.go b/tools/checklocks/analysis.go
index d3fd797d0..ec0cba7f9 100644
--- a/tools/checklocks/analysis.go
+++ b/tools/checklocks/analysis.go
@@ -183,8 +183,8 @@ type instructionWithReferrers interface {
// checkFieldAccess checks the validity of a field access.
//
// This also enforces atomicity constraints for fields that must be accessed
-// atomically. The parameter isWrite indicates whether this field is used
-// downstream for a write operation.
+// atomically. The parameter isWrite indicates whether this field is used for
+// a write operation.
func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj ssa.Value, field int, ls *lockState, isWrite bool) {
var (
lff lockFieldFacts
@@ -200,13 +200,14 @@ func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj
for guardName, fl := range lgf.GuardedBy {
guardsFound++
r := fl.resolve(structObj)
- if _, ok := ls.isHeld(r); ok {
+ if _, ok := ls.isHeld(r, isWrite); ok {
guardsHeld++
continue
}
if _, ok := pc.forced[pc.positionKey(inst.Pos())]; ok {
- // Mark this as locked, since it has been forced.
- ls.lockField(r)
+ // Mark this as locked, since it has been forced. All
+ // forces are treated as an exclusive lock.
+ ls.lockField(r, true /* exclusive */)
guardsHeld++
continue
}
@@ -301,7 +302,7 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction
continue
}
r := fg.resolveCall(call.Common().Args, call.Value())
- if s, ok := ls.unlockField(r); !ok {
+ if s, ok := ls.unlockField(r, fg.Exclusive); !ok {
if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
pc.maybeFail(call.Pos(), "attempt to release %s (%s), but not held (locks: %s)", fieldName, s, ls.String())
}
@@ -315,7 +316,7 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction
}
// Acquire the lock per the annotation.
r := fg.resolveCall(call.Common().Args, call.Value())
- if s, ok := ls.lockField(r); !ok {
+ if s, ok := ls.lockField(r, fg.Exclusive); !ok {
if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
pc.maybeFail(call.Pos(), "attempt to acquire %s (%s), but already held (locks: %s)", fieldName, s, ls.String())
}
@@ -323,6 +324,14 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction
}
}
+// exclusiveStr returns a string describing exclusive requirements.
+func exclusiveStr(exclusive bool) string {
+ if exclusive {
+ return "exclusively"
+ }
+ return "non-exclusively"
+}
+
// checkFunctionCall checks preconditions for function calls, and tracks the
// lock state by recording relevant calls to sync functions. Note that calls to
// atomic functions are tracked by checkFieldAccess by looking directly at the
@@ -332,12 +341,12 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff
// Check all guards required are held.
for fieldName, fg := range lff.HeldOnEntry {
r := fg.resolveCall(call.Common().Args, call.Value())
- if s, ok := ls.isHeld(r); !ok {
+ if s, ok := ls.isHeld(r, fg.Exclusive); !ok {
if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
- pc.maybeFail(call.Pos(), "must hold %s (%s) to call %s, but not held (locks: %s)", fieldName, s, fn.Name(), ls.String())
+ pc.maybeFail(call.Pos(), "must hold %s %s (%s) to call %s, but not held (locks: %s)", fieldName, exclusiveStr(fg.Exclusive), s, fn.Name(), ls.String())
} else {
// Force the lock to be acquired.
- ls.lockField(r)
+ ls.lockField(r, fg.Exclusive)
}
}
}
@@ -348,19 +357,33 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff
// Check if it's a method dispatch for something in the sync package.
// See: https://godoc.org/golang.org/x/tools/go/ssa#Function
if fn.Package() != nil && fn.Package().Pkg.Name() == "sync" && fn.Signature.Recv() != nil {
+ isExclusive := false
switch fn.Name() {
- case "Lock", "RLock":
- if s, ok := ls.lockField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok {
+ case "Lock":
+ isExclusive = true
+ fallthrough
+ case "RLock":
+ if s, ok := ls.lockField(resolvedValue{value: call.Common().Args[0], valid: true}, isExclusive); !ok {
if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
// Double locking a mutex that is already locked.
pc.maybeFail(call.Pos(), "%s already locked (locks: %s)", s, ls.String())
}
}
- case "Unlock", "RUnlock":
- if s, ok := ls.unlockField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok {
+ case "Unlock":
+ isExclusive = true
+ fallthrough
+ case "RUnlock":
+ if s, ok := ls.unlockField(resolvedValue{value: call.Common().Args[0], valid: true}, isExclusive); !ok {
if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
// Unlocking something that is already unlocked.
- pc.maybeFail(call.Pos(), "%s already unlocked (locks: %s)", s, ls.String())
+ pc.maybeFail(call.Pos(), "%s already unlocked or locked differently (locks: %s)", s, ls.String())
+ }
+ }
+ case "DowngradeLock":
+ if s, ok := ls.downgradeField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok {
+ if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
+ // Downgrading something that may not be downgraded.
+ pc.maybeFail(call.Pos(), "%s already unlocked or not exclusive (locks: %s)", s, ls.String())
}
}
}
@@ -531,13 +554,13 @@ func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock,
// Validate held locks.
for fieldName, fg := range lff.HeldOnExit {
r := fg.resolveStatic(fn, rv)
- if s, ok := rls.isHeld(r); !ok {
+ if s, ok := rls.isHeld(r, fg.Exclusive); !ok {
if _, ok := pc.forced[pc.positionKey(rv.Pos())]; !ok {
- pc.maybeFail(rv.Pos(), "lock %s (%s) not held (locks: %s)", fieldName, s, rls.String())
+ pc.maybeFail(rv.Pos(), "lock %s (%s) not held %s (locks: %s)", fieldName, s, exclusiveStr(fg.Exclusive), rls.String())
failed = true
} else {
// Force the lock to be acquired.
- rls.lockField(r)
+ rls.lockField(r, fg.Exclusive)
}
}
}
@@ -558,7 +581,7 @@ func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock,
if pls := pc.checkBasicBlock(fn, succ, lff, ls, seen); pls != nil {
if rls != nil && !rls.isCompatible(pls) {
if _, ok := pc.forced[pc.positionKey(fn.Pos())]; !ok {
- pc.maybeFail(fn.Pos(), "incompatible return states (first: %s, second: %v)", rls.String(), pls.String())
+ pc.maybeFail(fn.Pos(), "incompatible return states (first: %s, second: %s)", rls.String(), pls.String())
}
}
rls = pls
@@ -601,11 +624,11 @@ func (pc *passContext) checkFunction(call callCommon, fn *ssa.Function, lff *loc
// The first is the method object itself so we skip that when looking
// for receiver/function parameters.
r := fg.resolveStatic(fn, call.Value())
- if s, ok := ls.lockField(r); !ok {
+ if s, ok := ls.lockField(r, fg.Exclusive); !ok {
// This can only happen if the same value is declared
// multiple times, and should be caught by the earlier
// fact scanning. Keep it here as a sanity check.
- pc.maybeFail(fn.Pos(), "lock %s (%s) acquired multiple times (locks: %s)", fieldName, s, ls.String())
+ pc.maybeFail(fn.Pos(), "lock %s (%s) acquired multiple times or differently (locks: %s)", fieldName, s, ls.String())
}
}
diff --git a/tools/checklocks/annotations.go b/tools/checklocks/annotations.go
index 371260980..1f679e5be 100644
--- a/tools/checklocks/annotations.go
+++ b/tools/checklocks/annotations.go
@@ -23,13 +23,16 @@ import (
)
const (
- checkLocksAnnotation = "// +checklocks:"
- checkLocksAcquires = "// +checklocksacquire:"
- checkLocksReleases = "// +checklocksrelease:"
- checkLocksIgnore = "// +checklocksignore"
- checkLocksForce = "// +checklocksforce"
- checkLocksFail = "// +checklocksfail"
- checkAtomicAnnotation = "// +checkatomic"
+ checkLocksAnnotation = "// +checklocks:"
+ checkLocksAnnotationRead = "// +checklocksread:"
+ checkLocksAcquires = "// +checklocksacquire:"
+ checkLocksAcquiresRead = "// +checklocksacquireread:"
+ checkLocksReleases = "// +checklocksrelease:"
+ checkLocksReleasesRead = "// +checklocksreleaseread:"
+ checkLocksIgnore = "// +checklocksignore"
+ checkLocksForce = "// +checklocksforce"
+ checkLocksFail = "// +checklocksfail"
+ checkAtomicAnnotation = "// +checkatomic"
)
// failData indicates an expected failure.
diff --git a/tools/checklocks/facts.go b/tools/checklocks/facts.go
index 34c9f5ef1..fd681adc3 100644
--- a/tools/checklocks/facts.go
+++ b/tools/checklocks/facts.go
@@ -166,6 +166,9 @@ type functionGuard struct {
// FieldList is the traversal path to the object.
FieldList fieldList
+
+ // Exclusive indicates an exclusive lock is required.
+ Exclusive bool
}
// resolveReturn resolves a return value.
@@ -262,7 +265,7 @@ type lockFunctionFacts struct {
func (*lockFunctionFacts) AFact() {}
// checkGuard validates the guardName.
-func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, allowReturn bool) (functionGuard, bool) {
+func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuard, bool) {
if _, ok := lff.HeldOnEntry[guardName]; ok {
pc.maybeFail(d.Pos(), "annotation %s specified more than once, already required", guardName)
return functionGuard{}, false
@@ -271,13 +274,13 @@ func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guard
pc.maybeFail(d.Pos(), "annotation %s specified more than once, already acquired", guardName)
return functionGuard{}, false
}
- fg, ok := pc.findFunctionGuard(d, guardName, allowReturn)
+ fg, ok := pc.findFunctionGuard(d, guardName, exclusive, allowReturn)
return fg, ok
}
// addGuardedBy adds a field to both HeldOnEntry and HeldOnExit.
-func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, guardName string) {
- if fg, ok := lff.checkGuard(pc, d, guardName, false /* allowReturn */); ok {
+func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) {
+ if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, false /* allowReturn */); ok {
if lff.HeldOnEntry == nil {
lff.HeldOnEntry = make(map[string]functionGuard)
}
@@ -290,8 +293,8 @@ func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, gua
}
// addAcquires adds a field to HeldOnExit.
-func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guardName string) {
- if fg, ok := lff.checkGuard(pc, d, guardName, true /* allowReturn */); ok {
+func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) {
+ if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, true /* allowReturn */); ok {
if lff.HeldOnExit == nil {
lff.HeldOnExit = make(map[string]functionGuard)
}
@@ -300,8 +303,8 @@ func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guar
}
// addReleases adds a field to HeldOnEntry.
-func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guardName string) {
- if fg, ok := lff.checkGuard(pc, d, guardName, false /* allowReturn */); ok {
+func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) {
+ if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, false /* allowReturn */); ok {
if lff.HeldOnEntry == nil {
lff.HeldOnEntry = make(map[string]functionGuard)
}
@@ -310,7 +313,7 @@ func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guar
}
// fieldListFor returns the fieldList for the given object.
-func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index int, fieldName string, checkMutex bool) (int, bool) {
+func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index int, fieldName string, checkMutex bool, exclusive bool) (int, bool) {
var lff lockFieldFacts
if !pc.pass.ImportObjectFact(fieldObj, &lff) {
// This should not happen: we export facts for all fields.
@@ -318,7 +321,11 @@ func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index
}
// Check that it is indeed a mutex.
if checkMutex && !lff.IsMutex && !lff.IsRWMutex {
- pc.maybeFail(pos, "field %s is not a mutex or an rwmutex", fieldName)
+ pc.maybeFail(pos, "field %s is not a Mutex or an RWMutex", fieldName)
+ return 0, false
+ }
+ if checkMutex && !exclusive && !lff.IsRWMutex {
+ pc.maybeFail(pos, "field %s must be a RWMutex, but it is not", fieldName)
return 0, false
}
// Return the resolution path.
@@ -329,14 +336,14 @@ func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index
}
// resolveOneField resolves a field in a single struct.
-func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, fieldName string, checkMutex bool) (fl fieldList, fieldObj types.Object, ok bool) {
+func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, fieldName string, checkMutex bool, exclusive bool) (fl fieldList, fieldObj types.Object, ok bool) {
// Scan to match the next field.
for i := 0; i < structType.NumFields(); i++ {
fieldObj := structType.Field(i)
if fieldObj.Name() != fieldName {
continue
}
- flOne, ok := pc.fieldListFor(pos, fieldObj, i, fieldName, checkMutex)
+ flOne, ok := pc.fieldListFor(pos, fieldObj, i, fieldName, checkMutex, exclusive)
if !ok {
return nil, nil, false
}
@@ -357,8 +364,8 @@ func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct,
// Need to check that there is a resolution path. If there is
// no resolution path that's not a failure: we just continue
// scanning the next embed to find a match.
- flEmbed, okEmbed := pc.fieldListFor(pos, fieldObj, i, fieldName, false)
- flCont, fieldObjCont, okCont := pc.resolveOneField(pos, structType, fieldName, checkMutex)
+ flEmbed, okEmbed := pc.fieldListFor(pos, fieldObj, i, fieldName, false, exclusive)
+ flCont, fieldObjCont, okCont := pc.resolveOneField(pos, structType, fieldName, checkMutex, exclusive)
if okEmbed && okCont {
fl = append(fl, flEmbed)
fl = append(fl, flCont...)
@@ -373,9 +380,9 @@ func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct,
//
// Note that this checks that the final element is a mutex of some kind, and
// will fail appropriately.
-func (pc *passContext) resolveField(pos token.Pos, structType *types.Struct, parts []string) (fl fieldList, ok bool) {
+func (pc *passContext) resolveField(pos token.Pos, structType *types.Struct, parts []string, exclusive bool) (fl fieldList, ok bool) {
for partNumber, fieldName := range parts {
- flOne, fieldObj, ok := pc.resolveOneField(pos, structType, fieldName, partNumber >= len(parts)-1 /* checkMutex */)
+ flOne, fieldObj, ok := pc.resolveOneField(pos, structType, fieldName, partNumber >= len(parts)-1 /* checkMutex */, exclusive)
if !ok {
// Error already reported.
return nil, false
@@ -474,16 +481,22 @@ func (pc *passContext) exportLockGuardFacts(structType *types.Struct, ss *ast.St
pc.maybeFail(fieldObj.Pos(), "annotation %s specified more than once", guardName)
return
}
- fl, ok := pc.resolveField(fieldObj.Pos(), structType, strings.Split(guardName, "."))
+ fl, ok := pc.resolveField(fieldObj.Pos(), structType, strings.Split(guardName, "."), true /* exclusive */)
if ok {
- // If we successfully resolved
- // the field, then save it.
+ // If we successfully resolved the field, then save it.
if lgf.GuardedBy == nil {
lgf.GuardedBy = make(map[string]fieldList)
}
lgf.GuardedBy[guardName] = fl
}
},
+ // N.B. We support only the vanilla
+ // annotation on individual fields. If
+ // the field is a read lock, then we
+ // will allow read access by default.
+ checkLocksAnnotationRead: func(guardName string) {
+ pc.maybeFail(fieldObj.Pos(), "annotation %s not legal on fields", guardName)
+ },
})
}
// Save only if there is something meaningful.
@@ -514,7 +527,7 @@ func countFields(fl []*ast.Field) (count int) {
}
// matchFieldList attempts to match the given field.
-func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName string) (functionGuard, bool) {
+func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName string, exclusive bool) (functionGuard, bool) {
parts := strings.Split(guardName, ".")
parameterName := parts[0]
parameterNumber := 0
@@ -545,8 +558,9 @@ func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName
}
fg := functionGuard{
ParameterNumber: parameterNumber,
+ Exclusive: exclusive,
}
- fl, ok := pc.resolveField(pos, structType, parts[1:])
+ fl, ok := pc.resolveField(pos, structType, parts[1:], exclusive)
fg.FieldList = fl
return fg, ok // If ok is false, already failed.
}
@@ -558,7 +572,7 @@ func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName
// particular string of the 'a.b'.
//
// This function will report any errors directly.
-func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, allowReturn bool) (functionGuard, bool) {
+func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuard, bool) {
var (
parameterList []*ast.Field
returnList []*ast.Field
@@ -569,14 +583,14 @@ func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, allo
if d.Type.Params != nil {
parameterList = append(parameterList, d.Type.Params.List...)
}
- if fg, ok := pc.matchFieldList(d.Pos(), parameterList, guardName); ok {
+ if fg, ok := pc.matchFieldList(d.Pos(), parameterList, guardName, exclusive); ok {
return fg, ok
}
if allowReturn {
if d.Type.Results != nil {
returnList = append(returnList, d.Type.Results.List...)
}
- if fg, ok := pc.matchFieldList(d.Pos(), returnList, guardName); ok {
+ if fg, ok := pc.matchFieldList(d.Pos(), returnList, guardName, exclusive); ok {
// Fix this up to apply to the return value, as noted
// in fg.ParameterNumber. For the ssa analysis, we must
// record whether this has multiple results, since
@@ -610,9 +624,12 @@ func (pc *passContext) exportFunctionFacts(d *ast.FuncDecl) {
// extremely rare.
lff.Ignore = true
},
- checkLocksAnnotation: func(guardName string) { lff.addGuardedBy(pc, d, guardName) },
- checkLocksAcquires: func(guardName string) { lff.addAcquires(pc, d, guardName) },
- checkLocksReleases: func(guardName string) { lff.addReleases(pc, d, guardName) },
+ checkLocksAnnotation: func(guardName string) { lff.addGuardedBy(pc, d, guardName, true /* exclusive */) },
+ checkLocksAnnotationRead: func(guardName string) { lff.addGuardedBy(pc, d, guardName, false /* exclusive */) },
+ checkLocksAcquires: func(guardName string) { lff.addAcquires(pc, d, guardName, true /* exclusive */) },
+ checkLocksAcquiresRead: func(guardName string) { lff.addAcquires(pc, d, guardName, false /* exclusive */) },
+ checkLocksReleases: func(guardName string) { lff.addReleases(pc, d, guardName, true /* exclusive */) },
+ checkLocksReleasesRead: func(guardName string) { lff.addReleases(pc, d, guardName, false /* exclusive */) },
})
}
diff --git a/tools/checklocks/state.go b/tools/checklocks/state.go
index 57061a32e..aaf997d79 100644
--- a/tools/checklocks/state.go
+++ b/tools/checklocks/state.go
@@ -29,7 +29,9 @@ type lockState struct {
// lockedMutexes is used to track which mutexes in a given struct are
// currently locked. Note that most of the heavy lifting is done by
// valueAsString below, which maps to specific structure fields, etc.
- lockedMutexes []string
+ //
+ // The value indicates whether this is an exclusive lock.
+ lockedMutexes map[string]bool
// stored stores values that have been stored in memory, bound to
// FreeVars or passed as Parameterse.
@@ -51,7 +53,7 @@ type lockState struct {
func newLockState() *lockState {
refs := int32(1) // Not shared.
return &lockState{
- lockedMutexes: make([]string, 0),
+ lockedMutexes: make(map[string]bool),
used: make(map[ssa.Value]struct{}),
stored: make(map[ssa.Value]ssa.Value),
defers: make([]*ssa.Defer, 0),
@@ -79,8 +81,10 @@ func (l *lockState) fork() *lockState {
func (l *lockState) modify() {
if atomic.LoadInt32(l.refs) > 1 {
// Copy the lockedMutexes.
- lm := make([]string, len(l.lockedMutexes))
- copy(lm, l.lockedMutexes)
+ lm := make(map[string]bool)
+ for k, v := range l.lockedMutexes {
+ lm[k] = v
+ }
l.lockedMutexes = lm
// Copy the stored values.
@@ -106,55 +110,76 @@ func (l *lockState) modify() {
}
// isHeld indicates whether the field is held is not.
-func (l *lockState) isHeld(rv resolvedValue) (string, bool) {
+func (l *lockState) isHeld(rv resolvedValue, exclusiveRequired bool) (string, bool) {
if !rv.valid {
return rv.valueAsString(l), false
}
s := rv.valueAsString(l)
- for _, k := range l.lockedMutexes {
- if k == s {
- return s, true
- }
+ isExclusive, ok := l.lockedMutexes[s]
+ if !ok {
+ return s, false
+ }
+ // Accept a weaker lock if exclusiveRequired is false.
+ if exclusiveRequired && !isExclusive {
+ return s, false
}
- return s, false
+ return s, true
}
// lockField locks the given field.
//
// If false is returned, the field was already locked.
-func (l *lockState) lockField(rv resolvedValue) (string, bool) {
+func (l *lockState) lockField(rv resolvedValue, exclusive bool) (string, bool) {
if !rv.valid {
return rv.valueAsString(l), false
}
s := rv.valueAsString(l)
- for _, k := range l.lockedMutexes {
- if k == s {
- return s, false
- }
+ if _, ok := l.lockedMutexes[s]; ok {
+ return s, false
}
l.modify()
- l.lockedMutexes = append(l.lockedMutexes, s)
+ l.lockedMutexes[s] = exclusive
return s, true
}
// unlockField unlocks the given field.
//
// If false is returned, the field was not locked.
-func (l *lockState) unlockField(rv resolvedValue) (string, bool) {
+func (l *lockState) unlockField(rv resolvedValue, exclusive bool) (string, bool) {
if !rv.valid {
return rv.valueAsString(l), false
}
s := rv.valueAsString(l)
- for i, k := range l.lockedMutexes {
- if k == s {
- // Copy the last lock in and truncate.
- l.modify()
- l.lockedMutexes[i] = l.lockedMutexes[len(l.lockedMutexes)-1]
- l.lockedMutexes = l.lockedMutexes[:len(l.lockedMutexes)-1]
- return s, true
- }
+ wasExclusive, ok := l.lockedMutexes[s]
+ if !ok {
+ return s, false
}
- return s, false
+ if wasExclusive != exclusive {
+ return s, false
+ }
+ l.modify()
+ delete(l.lockedMutexes, s)
+ return s, true
+}
+
+// downgradeField downgrades the given field.
+//
+// If false is returned, the field was not downgraded.
+func (l *lockState) downgradeField(rv resolvedValue) (string, bool) {
+ if !rv.valid {
+ return rv.valueAsString(l), false
+ }
+ s := rv.valueAsString(l)
+ wasExclusive, ok := l.lockedMutexes[s]
+ if !ok {
+ return s, false
+ }
+ if !wasExclusive {
+ return s, false
+ }
+ l.modify()
+ l.lockedMutexes[s] = false // Downgraded.
+ return s, true
}
// store records an alias.
@@ -165,16 +190,17 @@ func (l *lockState) store(addr ssa.Value, v ssa.Value) {
// isSubset indicates other holds all the locks held by l.
func (l *lockState) isSubset(other *lockState) bool {
- held := 0 // Number in l, held by other.
- for _, k := range l.lockedMutexes {
- for _, ok := range other.lockedMutexes {
- if k == ok {
- held++
- break
- }
+ for k, isExclusive := range l.lockedMutexes {
+ otherExclusive, otherOk := other.lockedMutexes[k]
+ if !otherOk {
+ return false
+ }
+ // Accept weaker locks as a subset.
+ if isExclusive && !otherExclusive {
+ return false
}
}
- return held >= len(l.lockedMutexes)
+ return true
}
// count indicates the number of locks held.
@@ -293,7 +319,12 @@ func (l *lockState) String() string {
if l.count() == 0 {
return "no locks held"
}
- return strings.Join(l.lockedMutexes, ",")
+ keys := make([]string, 0, len(l.lockedMutexes))
+ for k, exclusive := range l.lockedMutexes {
+ // Include the exclusive status of each lock.
+ keys = append(keys, fmt.Sprintf("%s %s", k, exclusiveStr(exclusive)))
+ }
+ return strings.Join(keys, ",")
}
// pushDefer pushes a defer onto the stack.
diff --git a/tools/checklocks/test/BUILD b/tools/checklocks/test/BUILD
index d4d98c256..f2ea6c7c6 100644
--- a/tools/checklocks/test/BUILD
+++ b/tools/checklocks/test/BUILD
@@ -16,6 +16,7 @@ go_library(
"methods.go",
"parameters.go",
"return.go",
+ "rwmutex.go",
"test.go",
],
)
diff --git a/tools/checklocks/test/basics.go b/tools/checklocks/test/basics.go
index 7a773171f..e941fba5b 100644
--- a/tools/checklocks/test/basics.go
+++ b/tools/checklocks/test/basics.go
@@ -108,14 +108,14 @@ type rwGuardStruct struct {
func testRWValidRead(tc *rwGuardStruct) {
tc.rwMu.Lock()
- tc.guardedField = 1
+ _ = tc.guardedField
tc.rwMu.Unlock()
}
func testRWValidWrite(tc *rwGuardStruct) {
- tc.rwMu.RLock()
+ tc.rwMu.Lock()
tc.guardedField = 2
- tc.rwMu.RUnlock()
+ tc.rwMu.Unlock()
}
func testRWInvalidWrite(tc *rwGuardStruct) {
diff --git a/tools/checklocks/test/rwmutex.go b/tools/checklocks/test/rwmutex.go
new file mode 100644
index 000000000..d27ed10e3
--- /dev/null
+++ b/tools/checklocks/test/rwmutex.go
@@ -0,0 +1,52 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import (
+ "sync"
+)
+
+// oneReadGuardStruct has one read-guarded field.
+type oneReadGuardStruct struct {
+ mu sync.RWMutex
+ // +checklocks:mu
+ guardedField int
+}
+
+func testRWAccessValidRead(tc *oneReadGuardStruct) {
+ tc.mu.Lock()
+ _ = tc.guardedField
+ tc.mu.Unlock()
+ tc.mu.RLock()
+ _ = tc.guardedField
+ tc.mu.RUnlock()
+}
+
+func testRWAccessValidWrite(tc *oneReadGuardStruct) {
+ tc.mu.Lock()
+ tc.guardedField = 1
+ tc.mu.Unlock()
+}
+
+func testRWAccessInvalidWrite(tc *oneReadGuardStruct) {
+ tc.guardedField = 2 // +checklocksfail
+ tc.mu.RLock()
+ tc.guardedField = 2 // +checklocksfail
+ tc.mu.RUnlock()
+}
+
+func testRWAccessInvalidRead(tc *oneReadGuardStruct) {
+ _ = tc.guardedField // +checklocksfail
+}
diff --git a/tools/go_fieldenum/main.go b/tools/go_fieldenum/main.go
index 68dfdb3db..d801bea1b 100644
--- a/tools/go_fieldenum/main.go
+++ b/tools/go_fieldenum/main.go
@@ -55,6 +55,7 @@ func main() {
// Determine which types are marked "+fieldenum" and will consequently have
// code generated.
+ var typeNames []string
fieldEnumTypes := make(map[string]fieldEnumTypeInfo)
for _, f := range inputFiles {
for _, decl := range f.Decls {
@@ -75,6 +76,7 @@ func main() {
if !ok {
log.Fatalf("Type %s is marked +fieldenum, but is not a struct", name)
}
+ typeNames = append(typeNames, name)
fieldEnumTypes[name] = fieldEnumTypeInfo{
prefix: prefix,
structType: st,
@@ -86,9 +88,10 @@ func main() {
}
// Collect information for each type for which code is being generated.
- structInfos := make([]structInfo, 0, len(fieldEnumTypes))
+ structInfos := make([]structInfo, 0, len(typeNames))
needSyncAtomic := false
- for typeName, typeInfo := range fieldEnumTypes {
+ for _, typeName := range typeNames {
+ typeInfo := fieldEnumTypes[typeName]
var si structInfo
si.name = typeName
si.prefix = typeInfo.prefix
@@ -204,13 +207,6 @@ func structFieldName(f *ast.Field) string {
}
}
-// Workaround for Go defect (map membership test isn't usable in an
-// expression).
-func fetContains(xs map[string]*ast.StructType, x string) bool {
- _, ok := xs[x]
- return ok
-}
-
func (si *structInfo) writeTo(b *strings.Builder) {
fmt.Fprintf(b, "// A %sField represents a field in %s.\n", si.prefix, si.name)
fmt.Fprintf(b, "type %sField uint\n\n", si.prefix)
diff --git a/tools/nogo/build.go b/tools/nogo/build.go
index 4067bb480..003533c71 100644
--- a/tools/nogo/build.go
+++ b/tools/nogo/build.go
@@ -31,3 +31,8 @@ func findStdPkg(GOOS, GOARCH, path string) (io.ReadCloser, error) {
}
return os.Open(fmt.Sprintf("external/go_sdk/pkg/%s_%s/%s.a", GOOS, GOARCH, path))
}
+
+// ReleaseTags returns nil, indicating that the defaults should be used.
+func ReleaseTags() ([]string, error) {
+ return nil, nil
+}
diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go
index 0e7e92965..17ca0d846 100644
--- a/tools/nogo/check/main.go
+++ b/tools/nogo/check/main.go
@@ -66,14 +66,22 @@ func run([]string) int {
return 1
}
+ releaseTags, err := nogo.ReleaseTags()
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "error determining release tags: %v", err)
+ return 1
+ }
+
// Run the configuration.
if *stdlibFile != "" {
// Perform stdlib analysis.
c := loadConfig(*stdlibFile, new(nogo.StdlibConfig)).(*nogo.StdlibConfig)
+ c.ReleaseTags = releaseTags
findings, factData, err = nogo.CheckStdlib(c, nogo.AllAnalyzers)
} else if *packageFile != "" {
// Perform standard analysis.
c := loadConfig(*packageFile, new(nogo.PackageConfig)).(*nogo.PackageConfig)
+ c.ReleaseTags = releaseTags
findings, factData, err = nogo.CheckPackage(c, nogo.AllAnalyzers, nil)
} else {
fmt.Fprintf(os.Stderr, "please provide at least one of package or stdlib!")
diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl
index dc9a8b24e..bc0d6874f 100644
--- a/tools/nogo/defs.bzl
+++ b/tools/nogo/defs.bzl
@@ -126,7 +126,7 @@ def _nogo_stdlib_impl(ctx):
Srcs = [f.path for f in go_ctx.stdlib_srcs],
GOOS = go_ctx.goos,
GOARCH = go_ctx.goarch,
- Tags = go_ctx.gotags,
+ BuildTags = go_ctx.gotags,
)
config_file = ctx.actions.declare_file(ctx.label.name + ".cfg")
ctx.actions.write(config_file, config.to_json())
@@ -321,7 +321,7 @@ def _nogo_aspect_impl(target, ctx):
NonGoFiles = [src.path for src in srcs if not src.path.endswith(".go")],
GOOS = go_ctx.goos,
GOARCH = go_ctx.goarch,
- Tags = go_ctx.gotags,
+ BuildTags = go_ctx.gotags,
FactMap = fact_map,
ImportMap = import_map,
StdlibFacts = stdlib_facts.path,
diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go
index d95d7652f..a96cb400a 100644
--- a/tools/nogo/nogo.go
+++ b/tools/nogo/nogo.go
@@ -52,10 +52,11 @@ import (
//
// This contains everything required for stdlib analysis.
type StdlibConfig struct {
- Srcs []string
- GOOS string
- GOARCH string
- Tags []string
+ Srcs []string
+ GOOS string
+ GOARCH string
+ BuildTags []string
+ ReleaseTags []string // Use build.Default if nil.
}
// PackageConfig is serialized as the configuration.
@@ -65,7 +66,8 @@ type PackageConfig struct {
ImportPath string
GoFiles []string
NonGoFiles []string
- Tags []string
+ BuildTags []string
+ ReleaseTags []string // Use build.Default if nil.
GOOS string
GOARCH string
ImportMap map[string]string
@@ -182,7 +184,10 @@ func (c *PackageConfig) shouldInclude(path string) (bool, error) {
ctx := build.Default
ctx.GOOS = c.GOOS
ctx.GOARCH = c.GOARCH
- ctx.BuildTags = c.Tags
+ ctx.BuildTags = c.BuildTags
+ if c.ReleaseTags != nil {
+ ctx.ReleaseTags = c.ReleaseTags
+ }
return ctx.MatchFile(filepath.Dir(path), filepath.Base(path))
}
@@ -293,6 +298,19 @@ func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindi
break
}
+ // Go standard library packages using Go 1.18 type parameter features.
+ //
+ // As of writing, analysis tooling is not updated to support type
+ // parameters and will choke on these packages. We skip these packages
+ // entirely for now.
+ //
+ // TODO(b/201686256): remove once tooling can handle type parameters.
+ usesTypeParams := map[string]struct{}{
+ "constraints": struct{}{}, // golang.org/issue/45458
+ "maps": struct{}{}, // golang.org/issue/47649
+ "slices": struct{}{}, // golang.org/issue/45955
+ }
+
// Aggregate all files by directory.
packages := make(map[string]*PackageConfig)
for _, file := range config.Srcs {
@@ -306,17 +324,25 @@ func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindi
continue // Not a file.
}
pkg := d[len(rootSrcPrefix):]
+
// Skip cmd packages and obvious test files: see above.
if strings.HasPrefix(pkg, "cmd/") || strings.HasSuffix(file, "_test.go") {
continue
}
+
+ if _, ok := usesTypeParams[pkg]; ok {
+ log.Printf("WARNING: Skipping package %q: type param analysis not yet supported", pkg)
+ continue
+ }
+
c, ok := packages[pkg]
if !ok {
c = &PackageConfig{
- ImportPath: pkg,
- GOOS: config.GOOS,
- GOARCH: config.GOARCH,
- Tags: config.Tags,
+ ImportPath: pkg,
+ GOOS: config.GOOS,
+ GOARCH: config.GOARCH,
+ BuildTags: config.BuildTags,
+ ReleaseTags: config.ReleaseTags,
}
packages[pkg] = c
}